postgresql_adapter.rb 43.9 KB
Newer Older
D
Initial  
David Heinemeier Hansson 已提交
1
require 'active_record/connection_adapters/abstract_adapter'
2
require 'active_support/core_ext/object/blank'
3
require 'active_record/connection_adapters/statement_pool'
4 5 6

# Make sure we're using a pg gem version new enough to provide PGresult#values
gem 'pg', '~> 0.11'
7
require 'pg'
D
Initial  
David Heinemeier Hansson 已提交
8 9 10 11 12

module ActiveRecord
  class Base
    # Establishes a connection to the database that's used by all Active Record objects.
    #
    # +config+ may use String or Symbol keys:
    # * :database - required; ArgumentError is raised when the key is absent.
    # * :host     - passed through as given (may be nil).
    # * :port     - defaults to 5432.
    # * :username / :password - converted to strings when present.
    def self.postgresql_connection(config) # :nodoc:
      config   = config.symbolize_keys
      host     = config[:host]
      port     = config[:port] || 5432
      username = config[:username].to_s if config[:username]
      password = config[:password].to_s if config[:password]

      unless config.key?(:database)
        raise ArgumentError, "No database specified. Missing argument: database."
      end
      database = config[:database]

      # The postgres drivers don't allow the creation of an unconnected PGconn object,
      # so just pass a nil connection object for the time being; the adapter
      # connects itself using these parameters.
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, [host, port, nil, nil, database, username, password], config)
    end
  end
30

31 32 33 34 35 36 37
  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      #
      # +default+ is the raw default expression reported by PostgreSQL
      # (e.g. "'foo'::character varying"); it is reduced to a plain value by
      # ::extract_value_from_default before being handed to Column.
      def initialize(name, default, sql_type = nil, null = true)
        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end

      # :stopdoc:
      class << self
        attr_accessor :money_precision

        # Like Column.string_to_time, but maps PostgreSQL's 'infinity' and
        # '-infinity' timestamps to Float infinities. Non-String input is
        # returned unchanged.
        def string_to_time(string)
          return string unless String === string

          case string
          when 'infinity'  then 1.0 / 0.0
          when '-infinity' then -1.0 / 0.0
          else
            super
          end
        end

        # Converts between Ruby hashes and hstore text.
        #
        # A Hash is serialized to hstore input ("k=>v, ..."); any other object
        # is treated as hstore output ('"k"=>"v", ...') and parsed into a Hash
        # by pairing up the double-quoted tokens it contains.
        def cast_hstore(object)
          if Hash === object
            object.map { |k,v|
              "#{escape_hstore(k)}=>#{escape_hstore(v)}"
            }.join ', '
          else
            kvs = object.scan(/(?<!\\)".*?(?<!\\)"/).map { |o|
              unescape_hstore(o[1...-1])
            }
            Hash[kvs.each_slice(2).to_a]
          end
        end

        private
        # Reverses the backslash escaping found inside a quoted hstore token.
        def unescape_hstore(value)
          escape_values = {
            '\\ '  => ' ',
            '\\\\' => '\\',
            '\\"'  => '"',
            '\\='  => '=',
          }
          value.gsub(Regexp.union(escape_values.keys)) do |match|
            escape_values[match]
          end
        end

        # Escapes a key or value for unquoted hstore input by backslash-escaping
        # spaces, backslashes, double quotes and '='.
        #
        # nil is emitted as the hstore NULL token. An empty string is emitted
        # as "" because an empty unquoted token yields invalid input such as
        # "a=>". Other non-String values are converted with to_s first.
        def escape_hstore(value)
          return 'NULL' if value.nil?

          value = value.to_s
          return '""' if value.empty?

          escape_values = {
            ' '  => '\\ ',
            '\\' => '\\\\',
            '"'  => '\\"',
            '='  => '\\=',
          }
          value.gsub(Regexp.union(escape_values.keys)) do |match|
            escape_values[match]
          end
        end
      end
      # :startdoc:

      private
        # bigint and smallint have fixed storage sizes (in bytes).
        def extract_limit(sql_type)
          case sql_type
          when /^bigint/i;    8
          when /^smallint/i;  2
          else super
          end
        end

        # Extracts the scale from PostgreSQL-specific data types.
        def extract_scale(sql_type)
          # Money type has a fixed scale of 2.
          sql_type =~ /^money/ ? 2 : super
        end

        # Extracts the precision from PostgreSQL-specific data types.
        def extract_precision(sql_type)
          if sql_type == 'money'
            self.class.money_precision
          else
            super
          end
        end

        # Maps PostgreSQL-specific data types to logical Rails types.
        def simplified_type(field_type)
          case field_type
          # Numeric and monetary types
          when /^(?:real|double precision)$/
            :float
          # Monetary types
          when 'money'
            :decimal
          when 'hstore'
            :hstore
          # Character types
          when /^(?:character varying|bpchar)(?:\(\d+\))?$/
            :string
          # Binary data types
          when 'bytea'
            :binary
          # Date/time types
          when /^timestamp with(?:out)? time zone$/
            :datetime
          when 'interval'
            :string
          # Geometric types
          when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
            :string
          # Network address types
          when /^(?:cidr|inet|macaddr)$/
            :string
          # Bit strings
          when /^bit(?: varying)?(?:\(\d+\))?$/
            :string
          # XML type
          when 'xml'
            :xml
          # tsvector type
          when 'tsvector'
            :tsvector
          # Arrays
          when /^\D+\[\]$/
            :string
          # Object identifier types
          when 'oid'
            :integer
          # UUID type
          when 'uuid'
            :string
          # Small and big integer types
          when /^(?:small|big)int$/
            :integer
          # Pass through all types that are not specific to PostgreSQL.
          else
            super
          end
        end

        # Extracts the value from a PostgreSQL column default definition.
        #
        # Returns the literal value (a String, or true/false for booleans) when
        # +default+ is a recognised constant, and nil for NULL, functions, user
        # types and anything else whose value cannot be known here.
        def self.extract_value_from_default(default)
          case default
          # Performance: return nil straight away. Matching nil against each
          # Regexp below goes through NilClass#to_str, which ActiveSupport's
          # whiny nil routes via method_missing and makes this very slow.
          when NilClass
            nil
          # Numeric types (this also covers plain object identifier values).
          # The optional closing parenthesis stays outside the capture so a
          # parenthesised default such as "(-1)" yields "-1", not "-1)".
          when /\A\(?(-?\d+(\.\d*)?)\)?\z/
            $1
          # Character types; embedded single quotes are doubled in the literal.
          when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
            $1.gsub("''", "'")
          # Character types (8.1 formatting)
          when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
            $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
          # Binary data types
          when /\A'(.*)'::bytea\z/m
            $1
          # Date/time types
          when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
            $1
          when /\A'(.*)'::interval\z/
            $1
          # Boolean type
          when 'true'
            true
          when 'false'
            false
          # Geometric types
          when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
            $1
          # Network address types
          when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
            $1
          # Bit string types
          when /\AB'(.*)'::"?bit(?: varying)?"?\z/
            $1
          # XML type
          when /\A'(.*)'::xml\z/m
            $1
          # Arrays
          when /\A'(.*)'::"?\D+"?\[\]\z/
            $1
          else
            # Anything else is blank, some user type, or some function
            # and we can't know the value of that, so return nil.
            nil
          end
        end
    end

231 232
    # The PostgreSQL adapter uses the native C driver provided by the +pg+ gem
    # (loaded via <tt>require 'pg'</tt>).
233 234 235
    #
    # Options:
    #
P
Pratik Naik 已提交
236 237 238 239 240
    # * <tt>:host</tt> - Defaults to nil, leaving the choice to the pg driver (typically a local Unix-domain socket).
    # * <tt>:port</tt> - Defaults to 5432.
    # * <tt>:username</tt> - Defaults to nothing.
    # * <tt>:password</tt> - Defaults to nothing.
    # * <tt>:database</tt> - The name of the database. No default, must be provided.
241
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection given
242
    #   as a string of comma-separated schema names. This is backward-compatible with the <tt>:schema_order</tt> option.
243
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO
244
    #   <encoding></tt> call on the connection.
245
    # * <tt>:min_messages</tt> - An optional client min messages that is used in a
246
    #   <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
247
    class PostgreSQLAdapter < AbstractAdapter
248 249 250 251 252
      class TableDefinition < ActiveRecord::ConnectionAdapters::TableDefinition
        # Adds one or more +xml+ columns, e.g. <tt>t.xml :payload, :null => false</tt>.
        # Every name given gets a column; a trailing options Hash applies to all.
        def xml(*args)
          add_pg_typed_columns('xml', args)
        end

        # Adds one or more +tsvector+ columns, e.g. <tt>t.tsvector :search</tt>.
        def tsvector(*args)
          add_pg_typed_columns('tsvector', args)
        end

        # Adds a single +hstore+ column.
        def hstore(name, options = {})
          column(name, 'hstore', options)
        end

        private
        # Adds a column of +type+ for each name in +args+ (after extracting a
        # trailing options Hash) and returns the result of the last #column call.
        def add_pg_typed_columns(type, args)
          options = args.extract_options!
          args.inject(nil) { |_, name| column(name, type, options) }
        end
      end

264
      ADAPTER_NAME = 'PostgreSQL'

      # Maps abstract Rails column types to PostgreSQL column definitions.
      # :hstore is listed so it can be used like any other native type,
      # matching the 'hstore' => :hstore mapping in PostgreSQLColumn.
      NATIVE_DATABASE_TYPES = {
        :primary_key => "serial primary key",
        :string      => { :name => "character varying", :limit => 255 },
        :text        => { :name => "text" },
        :integer     => { :name => "integer" },
        :float       => { :name => "float" },
        :decimal     => { :name => "decimal" },
        :datetime    => { :name => "timestamp" },
        :timestamp   => { :name => "timestamp" },
        :time        => { :name => "time" },
        :date        => { :name => "date" },
        :binary      => { :name => "bytea" },
        :boolean     => { :name => "boolean" },
        :xml         => { :name => "xml" },
        :tsvector    => { :name => "tsvector" },
        :hstore      => { :name => "hstore" }
      }

283
      # Returns 'PostgreSQL' as adapter name for identification purposes.
284
      def adapter_name
        ADAPTER_NAME
      end

288 289
      # Returns +true+, since this connection adapter supports prepared statement
      # caching.
290 291 292 293
      def supports_statement_cache?
        # Prepared statements are cached in the adapter's StatementPool.
        true
      end

294 295 296 297
      # Returns true: index columns may carry an explicit sort order (e.g. DESC).
      def supports_index_sort_order?
        true
      end

298 299 300 301
      class StatementPool < ConnectionAdapters::StatementPool
        def initialize(connection, max)
          super
          @counter = 0
          # Caches are keyed by process id ($$), so each process (e.g. a
          # forked child) works with its own set of statement names.
          @cache   = Hash.new { |h,pid| h[pid] = {} }
        end

        def each(&block); cache.each(&block); end
        def key?(key);    cache.key?(key); end
        def [](key);      cache[key]; end
        def length;       cache.length; end

        # Name to use for the next prepared statement ("a1", "a2", ...).
        def next_key
          "a#{@counter + 1}"
        end

        # Caches statement name +key+ for +sql+, deallocating the oldest
        # entries first so that no more than @max statements are kept.
        def []=(sql, key)
          while @max <= cache.size
            dealloc(cache.shift.last)
          end
          @counter += 1
          cache[sql] = key
        end

        # Deallocates and forgets every cached statement of this process.
        def clear
          cache.each_value do |stmt_key|
            dealloc stmt_key
          end
          cache.clear
        end

        # Deallocates and forgets the statement cached for +sql_key+.
        # Unknown keys are ignored instead of sending an invalid
        # "DEALLOCATE " query with no statement name.
        def delete(sql_key)
          return unless key?(sql_key)
          dealloc cache[sql_key]
          cache.delete sql_key
        end

        private
        def cache
          @cache[$$]
        end

        # Only talks to the server while the connection is usable.
        def dealloc(key)
          @connection.query "DEALLOCATE #{key}" if connection_active?
        end

        def connection_active?
          @connection.status == PGconn::CONNECTION_OK
        rescue PGError
          false
        end
      end

350 351
      # Initializes and connects a PostgreSQL adapter.
      def initialize(connection, logger, connection_parameters, config)
        super(connection, logger)
        @connection_parameters, @config = connection_parameters, config
        @visitor = Arel::Visitors::PostgreSQL.new self

        # @local_tz is initialized as nil to avoid warnings when connect tries to use it
        @local_tz = nil
        # Lazily filled by #table_alias_length.
        @table_alias_length = nil

        connect
        # Prepared statements are capped at :statement_limit (default 1000).
        @statements = StatementPool.new @connection,
                                        config.fetch(:statement_limit) { 1000 }

        if postgresql_version < 80200
          raise "Your version of PostgreSQL (#{postgresql_version}) is too old, please upgrade!"
        end

        @local_tz = execute('SHOW TIME ZONE', 'SCHEMA').first["TimeZone"]
      end

X
Xavier Noria 已提交
371
      # Clears the prepared statements cache.
372 373 374 375
      def clear_cache!
        # Delegates to the pool, which deallocates every cached statement.
        @statements.clear
      end

376 377
      # Is this connection alive and ready for queries?
      def active?
        # Any PGError while asking for the status means the connection is unusable.
        @connection.status == PGconn::CONNECTION_OK
      rescue PGError
        false
      end

      # Close then reopen the connection.
      def reconnect!
        # Cached prepared statements do not survive a reset, so drop them first.
        clear_cache!
        @connection.reset
        configure_connection
      end
389

390 391 392 393 394
      def reset!
        # Forget cached prepared statements before the generic reset runs.
        clear_cache!
        super
      end

395 396
      # Disconnects from the database if already connected. Otherwise, this
      # method does nothing.
397
      def disconnect!
        clear_cache!
        # Closing is best-effort: an already broken or closed connection must
        # not prevent the adapter from being torn down.
        begin
          @connection.close
        rescue StandardError
          nil
        end
      end
401

402
      def native_database_types #:nodoc:
        NATIVE_DATABASE_TYPES
      end
405

406
      # Returns true, since this connection adapter supports migrations.
407 408
      def supports_migrations?
        true
      end

411
      # Does PostgreSQL support finding primary key on non-Active Record tables?
412 413 414 415
      def supports_primary_key? #:nodoc:
        # Primary keys are looked up from the catalog, so any table qualifies.
        true
      end

416 417 418
      # Enable standard-conforming strings if available.
      def set_standard_conforming_strings
        # Read the current level before changing it, so the ensure clause
        # always restores a real value (never nil).
        old = client_min_messages
        begin
          # Silence server messages while attempting the change.
          self.client_min_messages = 'panic'
          begin
            execute('SET standard_conforming_strings = on', 'SCHEMA')
          rescue StandardError
            # Best-effort: if the server rejects the setting, carry on without it.
            nil
          end
        ensure
          self.client_min_messages = old
        end
      end

424
      # Returns true: INSERT ... RETURNING is used to fetch generated keys.
      def supports_insert_with_returning?
        true
      end

428 429 430
      # Returns true: schema changes (DDL) can run inside a transaction.
      def supports_ddl_transactions?
        true
      end
431

432
      # Returns true, since this connection adapter supports savepoints.
433 434 435
      def supports_savepoints?
        # See create_savepoint / rollback_to_savepoint / release_savepoint.
        true
      end
436

437 438 439 440 441
      # Returns true.
      def supports_explain?
        # EXPLAIN output is rendered by #explain.
        true
      end

442
      # Returns the maximum identifier length configured on the PostgreSQL server (max_identifier_length).
443
      def table_alias_length
        # Memoized: the server setting is queried only once per adapter.
        @table_alias_length ||= query('SHOW max_identifier_length')[0][0].to_i
      end
446

447 448
      # QUOTING ==================================================

449
      # Escapes binary strings for bytea input to the database.
450 451
      def escape_bytea(value)
        # nil passes through as nil.
        @connection.escape_bytea(value) if value
      end

      # Unescapes bytea output from a database to the binary string it represents.
      # NOTE: This is NOT an inverse of escape_bytea! It is only meant to be used
      #       on escaped binary output coming from the database driver.
457 458
      def unescape_bytea(value)
        # nil passes through as nil.
        @connection.unescape_bytea(value) if value
      end

461 462
      # Quotes PostgreSQL-specific data types for SQL input.
      def quote(value, column = nil) #:nodoc:
        return super unless column

        case value
        when Float
          # Infinite timestamps are written as the 'infinity' / '-infinity' literals.
          return super unless value.infinite? && column.type == :datetime
          "'#{value.to_s.downcase}'"
        when Numeric
          return super unless column.sql_type == 'money'
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value}'"
        when String
          case column.sql_type
          when 'bytea' then "'#{escape_bytea(value)}'"
          when 'xml'   then "xml '#{quote_string(value)}'"
          when /^bit/
            # The value is interpolated without escaping, so the WHOLE string
            # must be validated (\A..\z); ^/$ would accept a multi-line value
            # whose first line looks valid. Anything else falls back to super.
            case value
            when /\A[01]*\z/      then "B'#{value}'" # Bit-string notation
            when /\A[0-9A-F]*\z/i then "X'#{value}'" # Hexadecimal notation
            else super
            end
          else
            super
          end
        else
          super
        end
      end

490 491 492 493 494 495
      def type_cast(value, column)
        return super unless column

        case value
        when String
          return super unless 'bytea' == column.sql_type
          # Bind bytea parameters with :format => 1 so the pg driver sends
          # them in binary form rather than as escaped text.
          { :value => value, :format => 1 }
        else
          super
        end
      end

502 503 504
      # Quotes strings for use in SQL input.
      def quote_string(s) #:nodoc:
        # Escaping is delegated to the connection's escape.
        @connection.escape(s)
      end

507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525
      # Checks the following cases:
      #
      # - table_name
      # - "table.name"
      # - schema_name.table_name
      # - schema_name."table.name"
      # - "schema.name".table_name
      # - "schema.name"."table.name"
      def quote_table_name(name)
        schema, name_part = extract_pg_identifier_from_name(name.to_s)
        # A bare identifier (no schema part) is quoted on its own.
        return quote_column_name(schema) unless name_part

        # Anything after the table identifier is ignored.
        table_name, _rest = extract_pg_identifier_from_name(name_part)
        "#{quote_column_name(schema)}.#{quote_column_name(table_name)}"
      end

526 527
      # Quotes column names for use in SQL queries.
      def quote_column_name(name) #:nodoc:
        # Symbols and other objects are converted with to_s before quoting.
        PGconn.quote_ident(name.to_s)
      end

531 532 533
      # Quote date/time values for use in SQL input. Includes microseconds
      # if the value is a Time responding to usec.
      def quoted_date(value) #:nodoc:
        # Time-like values keep their microseconds as a six-digit fraction.
        if value.acts_like?(:time) && value.respond_to?(:usec)
          "#{super}.#{sprintf("%06d", value.usec)}"
        else
          super
        end
      end

541 542
      # Set the authorized user for this session
      def session_auth=(user)
        # Prepared statements were created under the previous role; drop them.
        clear_cache!
        # NOTE(review): +user+ is interpolated unquoted; only pass trusted
        # values (e.g. from database.yml).
        exec_query "SET SESSION AUTHORIZATION #{user}"
      end

547 548
      # REFERENTIAL INTEGRITY ====================================

549
      def supports_disable_referential_integrity? #:nodoc:
        true
      end

553
      # Runs the given block with all triggers (including foreign-key checks)
      # disabled on every table returned by #tables; they are re-enabled in
      # the ensure clause even if the block raises. Returns the block's value.
      def disable_referential_integrity #:nodoc:
        if supports_disable_referential_integrity?
          execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} DISABLE TRIGGER ALL" }.join(";"))
        end
        yield
      ensure
        if supports_disable_referential_integrity?
          execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} ENABLE TRIGGER ALL" }.join(";"))
        end
      end
563 564 565

      # DATABASE STATEMENTS ======================================

566
      # Returns the EXPLAIN plan for +arel+, formatted like the psql shell output.
      def explain(arel, binds = [])
        sql = "EXPLAIN #{to_sql(arel)}"
        ExplainPrettyPrinter.new.pp(exec_query(sql, 'EXPLAIN', binds))
      end

      class ExplainPrettyPrinter # :nodoc:
        # Renders an EXPLAIN result the way the PostgreSQL shell does:
        #
        #                                     QUERY PLAN
        #   ------------------------------------------------------------------------------
        #    Nested Loop Left Join  (cost=0.00..37.24 rows=8 width=0)
        #      Join Filter: (posts.user_id = users.id)
        #      ->  Index Scan using users_pkey on users  (cost=0.00..8.27 rows=1 width=4)
        #            Index Cond: (id = 1)
        #      ->  Seq Scan on posts  (cost=0.00..28.88 rows=8 width=4)
        #            Filter: (posts.user_id = 1)
        #   (6 rows)
        #
        # The header is centered, the rule spans the widest line plus one
        # character of padding on each side, and a row count closes the output.
        def pp(result)
          title      = result.columns.first
          plan_lines = result.rows.map { |row| row.first }

          width = ([title] + plan_lines).map { |text| text.length }.max + 2

          row_count = result.rows.length
          noun      = row_count == 1 ? 'row' : 'rows'

          output = [title.center(width).rstrip, '-' * width]
          plan_lines.each { |line| output << " #{line}" }
          output << "(#{row_count} #{noun})"

          "#{output.join("\n")}\n"
        end
      end

608 609 610 611 612 613
      # Executes a SELECT query and returns an array of rows. Each row is an
      # array of field values.
      def select_rows(sql, name = nil)
        # select_raw's last element holds the row arrays.
        raw = select_raw(sql, name)
        raw.last
      end

614
      # Executes an INSERT query and returns the new record's ID
615
      def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        unless pk
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          pk = primary_key(table_ref) if table_ref
        end

        if pk
          # Let the server hand back the generated key directly.
          select_value("#{sql} RETURNING #{quote_column_name(pk)}")
        else
          super
        end
      end
628
      alias create insert
629

630 631
      # create a 2D array representing the result set
      def result_as_array(res) #:nodoc:
        # check if we have any binary column and if they need escaping
        ftypes = Array.new(res.nfields) do |i|
          [i, res.ftype(i)]
        end

        rows = res.values
        return rows unless ftypes.any? { |_, x|
          x == BYTEA_COLUMN_TYPE_OID || x == MONEY_COLUMN_TYPE_OID
        }

        typehash = ftypes.group_by { |_, type| type }
        binaries = typehash[BYTEA_COLUMN_TYPE_OID] || []
        monies   = typehash[MONEY_COLUMN_TYPE_OID] || []

        rows.each do |row|
          # unescape string passed BYTEA field (OID == 17)
          binaries.each do |index, _|
            row[index] = unescape_bytea(row[index])
          end

          # If this is a money type column and there are any currency symbols,
          # then strip them off. Indeed it would be prettier to do this in
          # PostgreSQLColumn.string_to_decimal but would break form input
          # fields that call value_before_type_cast.
          monies.each do |index, _|
            data = row[index]
            # Because money output is formatted according to the locale, there are two
            # cases to consider (note the decimal separators):
            #  (1) $12,345,678.12
            #  (2) $12.345.678,12
            case data
            when /\A-?\D+[\d,]+\.\d{2}\z/  # (1)
              data.gsub!(/[^-\d.]/, '')
            when /\A-?\D+[\d.]+,\d{2}\z/  # (2)
              # Two separate statements: gsub! returns nil when it changes
              # nothing (e.g. "-1234,56"), so chaining sub! onto it would raise.
              data.gsub!(/[^-\d,]/, '')
              data.sub!(/,/, '.')
            end
          end
        end

        rows
      end


      # Queries the database and returns the results in an Array-like object
674
      def query(sql, name = nil) #:nodoc:
        # Returns plain row arrays (see result_as_array).
        log(sql, name) do
          result_as_array @connection.async_exec(sql)
        end
      end

680
      # Executes an SQL statement, returning a PGresult object on success
681 682
      # or raising a PGError exception otherwise.
      def execute(sql, name = nil)
        # Returns the raw result of async_exec, wrapped in logging.
        log(sql, name) do
          @connection.async_exec(sql)
        end
      end

688 689
      # Bind placeholders are 1-based ($1, $2, ...) while +index+ is 0-based.
      def substitute_at(column, index)
        Arel.sql("$#{index + 1}")
      end

A
Aaron Patterson 已提交
692
      # Runs +sql+ (with prepared-statement caching when +binds+ are given)
      # and returns an ActiveRecord::Result. The PGresult is always cleared,
      # even when converting it raises.
      def exec_query(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = binds.empty? ? exec_no_cache(sql, binds) :
                                  exec_cache(sql, binds)
          begin
            ActiveRecord::Result.new(result.fields, result_as_array(result))
          ensure
            result.clear
          end
        end
      end

703 704 705 706 707 708 709 710 711
      # Runs a DELETE (or UPDATE, via exec_update) and returns the number of
      # affected rows. The PGresult is always cleared, as in exec_query.
      def exec_delete(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = binds.empty? ? exec_no_cache(sql, binds) :
                                  exec_cache(sql, binds)
          begin
            result.cmd_tuples
          ensure
            result.clear
          end
        end
      end
      alias :exec_update :exec_delete
713

714 715
      # Appends "RETURNING <pk>" when a primary key is given or can be found
      # for the target table; returns the (possibly extended) SQL and binds.
      def sql_for_insert(sql, pk, id_value, sequence_name, binds)
        unless pk
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          pk = primary_key(table_ref) if table_ref
        end

        sql = "#{sql} RETURNING #{quote_column_name(pk)}" if pk

        [sql, binds]
      end

726
      # Executes an UPDATE query and returns the number of affected tuples.
727
      def update_sql(sql, name = nil)
        # The inherited implementation returns the PGresult; report its row count.
        super.cmd_tuples
      end

731 732
      # Begins a transaction.
      def begin_db_transaction
        execute "BEGIN"
      end

736 737
      # Commits a transaction.
      def commit_db_transaction
        execute "COMMIT"
      end
740

741 742
      # Aborts a transaction.
      def rollback_db_transaction
        execute "ROLLBACK"
      end
745

746 747
      # True when the connection reports no open transaction (PQTRANS_IDLE).
      def outside_transaction?
        @connection.transaction_status == PGconn::PQTRANS_IDLE
      end
749

J
Jonathan Viney 已提交
750 751 752 753 754 755 756 757
      # Opens a savepoint named after the current nesting level.
      def create_savepoint
        name = current_savepoint_name
        execute("SAVEPOINT #{name}")
      end

      # Undoes everything since the current savepoint was created.
      def rollback_to_savepoint
        name = current_savepoint_name
        execute("ROLLBACK TO SAVEPOINT #{name}")
      end

758
      # Releases (commits into the enclosing transaction) the current savepoint.
      def release_savepoint
        execute("RELEASE SAVEPOINT #{current_savepoint_name}")
      end
761

762 763
      # SCHEMA STATEMENTS ========================================

764 765 766
      # Drops the database specified on the +name+ attribute
      # and creates it again using the provided +options+.
      def recreate_database(name, options = {}) #:nodoc:
        drop_database(name)
        create_database(name, options)
      end

771
      # Create a new PostgreSQL database. Options include <tt>:owner</tt>, <tt>:template</tt>,
772 773
      # <tt>:encoding</tt>, <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
      # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
774 775 776 777 778 779 780 781 782 783
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        # Keys are normalized first; :encoding defaults to utf8.
        options = options.symbolize_keys.reverse_merge(:encoding => "utf8")

        # Identifiers are quoted with quote_column_name and the encoding with
        # quote_string, so embedded quotes cannot break out of the statement.
        # map + join (rather than sum) works without ActiveSupport's String-aware sum.
        option_string = options.map do |key, value|
          case key
          when :owner
            " OWNER = #{quote_column_name(value)}"
          when :template
            " TEMPLATE = #{quote_column_name(value)}"
          when :encoding
            " ENCODING = '#{quote_string(value.to_s)}'"
          when :tablespace
            " TABLESPACE = #{quote_column_name(value)}"
          when :connection_limit
            " CONNECTION LIMIT = #{value}"
          else
            # Unrelated keys (e.g. a whole database.yml config) are ignored.
            ""
          end
        end.join

        execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
      end

801
      # Drops a PostgreSQL database.
802 803 804 805
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        # IF EXISTS makes dropping a missing database a no-op.
        execute "DROP DATABASE IF EXISTS #{quote_table_name(name)}"
      end

809 810
      # Returns the list of all tables in the schema search path or a specified schema.
      def tables(name = nil)
        # +name+ is accepted for interface compatibility and not used.
        query(<<-SQL, 'SCHEMA').map { |row| row[0] }
          SELECT tablename
          FROM pg_tables
          WHERE schemaname = ANY (current_schemas(false))
        SQL
      end

818
      # Returns true if table exists.
819 820
      # If the schema is not specified as part of +name+ then it will only find tables within
      # the current schema search path (regardless of permissions to access tables in other schemas)
821
      def table_exists?(name)
        schema, table = Utils.extract_schema_and_table(name.to_s)
        return false unless table

        # $1 is the table name; $2 (the schema) is only bound when given.
        binds = [[nil, table]]
        binds << [nil, schema] if schema

        exec_query(<<-SQL, 'SCHEMA', binds).rows.first[0].to_i > 0
            SELECT COUNT(*)
            FROM pg_class c
            LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind in ('v','r')
            AND c.relname = $1
            AND n.nspname = #{schema ? '$2' : 'ANY (current_schemas(false))'}
        SQL
      end

838 839 840 841 842 843 844 845
      # Returns true if schema exists.
      def schema_exists?(name)
        sql = <<-SQL
          SELECT COUNT(*)
          FROM pg_namespace
          WHERE nspname = $1
        SQL
        # The single COUNT(*) cell comes back as a string.
        count = exec_query(sql, 'SCHEMA', [[nil, name]]).rows.first[0]
        count.to_i > 0
      end
846

847
      # Returns an array of indexes for the given table.
848
      def indexes(table_name, name = nil)
        # The table name is escaped with quote_string before being embedded in
        # the string literal, so names containing quotes cannot break the query.
        result = query(<<-SQL, name)
          SELECT distinct i.relname, d.indisunique, d.indkey, pg_get_indexdef(d.indexrelid), t.oid
          FROM pg_class t
          INNER JOIN pg_index d ON t.oid = d.indrelid
          INNER JOIN pg_class i ON d.indexrelid = i.oid
          WHERE i.relkind = 'i'
            AND d.indisprimary = 'f'
            AND t.relname = '#{quote_string(table_name.to_s)}'
            AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname = ANY (current_schemas(false)) )
          ORDER BY i.relname
        SQL

        result.map do |row|
          index_name = row[0]
          unique     = row[1] == 't'
          indkey     = row[2].split(" ")
          inddef     = row[3]
          oid        = row[4]

          # attnum => attname for the columns referenced by this index.
          columns = Hash[query(<<-SQL, "Columns for index #{row[0]} on #{table_name}")]
          SELECT a.attnum, a.attname
          FROM pg_attribute a
          WHERE a.attrelid = #{oid}
          AND a.attnum IN (#{indkey.join(",")})
          SQL

          # Keep the index's column order; unknown attnums are dropped.
          column_names = columns.values_at(*indkey).compact

          # add info on sort order for columns (only desc order is explicitly specified, asc is the default)
          desc_order_columns = inddef.scan(/(\w+) DESC/).flatten
          orders = Hash[desc_order_columns.map { |order_column| [order_column, :desc] }]

          # Indexes without resolvable columns are skipped.
          column_names.empty? ? nil : IndexDefinition.new(table_name, index_name, unique, column_names, [], orders)
        end.compact
      end

886 887
      # Builds a PostgreSQLColumn for every column of +table_name+, in the
      # order returned by column_definitions. +name+ is accepted for interface
      # compatibility and is not used here.
      def columns(table_name, name = nil)
        column_definitions(table_name).map do |column_name, sql_type, default, notnull|
          # The column is nullable exactly when the not-null flag reads 'f'.
          nullable = notnull == 'f'
          # Limit, precision and scale are extracted by the Column superclass.
          PostgreSQLColumn.new(column_name, default, sql_type, nullable)
        end
      end

894 895 896 897 898
      # Returns the name of the database this connection is attached to,
      # as reported by the server's current_database().
      def current_database
        rows = query('select current_database()')
        rows.first.first
      end

899 900 901 902 903
      # Returns the current schema name as reported by the server's
      # current_schema.
      def current_schema
        schema_rows = query('SELECT current_schema', 'SCHEMA')
        schema_rows.first.first
      end

904 905 906 907 908 909 910 911
      # Returns the encoding name (via pg_encoding_to_char) of the database
      # this connection is attached to.
      #
      # The row is selected with "= current_database()" on the server side. A
      # LIKE comparison against an interpolated name treated "_" and "%" in
      # database names as wildcards and broke on names containing a quote.
      def encoding
        query(<<-end_sql)[0][0]
          SELECT pg_encoding_to_char(pg_database.encoding) FROM pg_database
          WHERE pg_database.datname = current_database()
        end_sql
      end

912 913 914 915 916 917
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      #
      # This should not be called manually but set in database.yml.
      # The value is interpolated into the SET statement verbatim. A nil
      # argument is ignored: no SQL is issued and the cached value is kept.
      def schema_search_path=(schema_csv)
        return unless schema_csv

        execute "SET search_path TO #{schema_csv}"
        @schema_search_path = schema_csv
      end

924 925
      # Returns the active schema search path. The server is asked with
      # SHOW search_path only when no value has been memoized yet.
      def schema_search_path
        @schema_search_path ||= query('SHOW search_path', 'SCHEMA').first.first
      end
928

929 930
      # Returns the current client message level (SHOW client_min_messages).
      def client_min_messages
        level_rows = query('SHOW client_min_messages', 'SCHEMA')
        level_rows.first.first
      end

      # Sets the client message level. +level+ is placed inside the quoted
      # literal of the SET statement as-is.
      def client_min_messages=(level)
        sql = "SET client_min_messages TO '#{level}'"
        execute(sql, 'SCHEMA')
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      # The schema prefix reported by serial_sequence is stripped. When that lookup
      # raises ActiveRecord::StatementInvalid, the conventional
      # "<table>_<key>_seq" name is returned instead.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        column = pk || 'id'
        begin
          serial_sequence(table_name, column).split('.').last
        rescue ActiveRecord::StatementInvalid
          "#{table_name}_#{column}_seq"
        end
      end

      # Returns the single value of pg_get_serial_sequence(table, column),
      # passing both names as bind parameters.
      def serial_sequence(table, column)
        binds = [[nil, table], [nil, column]]
        exec_query(<<-eosql, 'SCHEMA', binds).rows.first.first
          SELECT pg_get_serial_sequence($1, $2)
        eosql
      end

953 954
      # Resets the sequence of a table's primary key to the maximum value.
      #
      # If either +pk+ or +sequence+ is missing, defaults come from
      # pk_and_sequence_for; values passed explicitly take precedence.
      # Logs a warning when a primary key is found without a sequence, and
      # returns nil unless both a key and a sequence are known.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        if !pk || !sequence
          default_pk, default_sequence = pk_and_sequence_for(table)
          pk       ||= default_pk
          sequence ||= default_sequence
        end

        @logger.warn "#{table} has primary key #{pk} with no default sequence" if @logger && pk && !sequence

        return unless pk && sequence

        quoted_sequence = quote_table_name(sequence)
        select_value <<-end_sql, 'Reset sequence'
            SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
        end_sql
      end

975 976
      # Returns a table's primary key and belonging sequence as
      # <tt>[primary_key, sequence_name]</tt>, or nil when none is found.
      #
      # The sequence name is qualified with its schema unless that schema is
      # "public".
      def pk_and_sequence_for(table) #:nodoc:
        # Look for a sequence with a dependency on the given table's primary key.
        result = exec_query(<<-end_sql, 'SCHEMA').rows.first
          SELECT attr.attname, ns.nspname, seq.relname
          FROM pg_class seq
          INNER JOIN pg_depend dep ON seq.oid = dep.objid
          INNER JOIN pg_attribute attr ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          INNER JOIN pg_namespace ns ON seq.relnamespace = ns.oid
          WHERE seq.relkind  = 'S'
            AND cons.contype = 'p'
            AND dep.refobjid = '#{quote_table_name(table)}'::regclass
        end_sql

        # No row means no sequence-backed primary key; say so explicitly rather
        # than relying on a NoMethodError on nil being rescued below.
        return nil unless result

        pk, schema, sequence_name = result
        sequence = schema == 'public' ? sequence_name : "#{schema}.#{sequence_name}"
        [pk, sequence]
      rescue StandardError
        # Deliberate best effort: any failed lookup (for example an unknown
        # table failing the ::regclass cast) yields nil instead of an error.
        nil
      end

1003 1004
      # Returns just the name of a table's primary key column, or nil when
      # the lookup query returns no row.
      def primary_key(table)
        rows = exec_query(<<-end_sql, 'SCHEMA', [[nil, table]]).rows
          SELECT DISTINCT(attr.attname)
          FROM pg_attribute attr
          INNER JOIN pg_depend dep ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          WHERE cons.contype = 'p'
            AND dep.refobjid = $1::regclass
        end_sql
        first_row = rows.first
        first_row.first if first_row
      end

1017
      # Renames a table, clearing the adapter cache first.
      #
      # Example:
      #   rename_table('octopuses', 'octopi')
      def rename_table(name, new_name)
        clear_cache!
        from = quote_table_name(name)
        to   = quote_table_name(new_name)
        execute "ALTER TABLE #{from} RENAME TO #{to}"
      end
1025

1026 1027
      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
      def add_column(table_name, column_name, type, options = {})
        clear_cache!
        quoted_table  = quote_table_name(table_name)
        quoted_column = quote_column_name(column_name)
        column_type   = type_to_sql(type, options[:limit], options[:precision], options[:scale])

        sql = "ALTER TABLE #{quoted_table} ADD COLUMN #{quoted_column} #{column_type}"
        # Appends DEFAULT / NOT NULL clauses to +sql+ in place.
        add_column_options!(sql, options)

        execute sql
      end
D
Initial  
David Heinemeier Hansson 已提交
1035

1036 1037
      # Changes the type of a table column. The default is also changed when
      # the options include a default, and nullability when they include :null.
      def change_column(table_name, column_name, type, options = {})
        clear_cache!
        quoted_table  = quote_table_name(table_name)
        quoted_column = quote_column_name(column_name)
        new_type      = type_to_sql(type, options[:limit], options[:precision], options[:scale])

        execute "ALTER TABLE #{quoted_table} ALTER COLUMN #{quoted_column} TYPE #{new_type}"

        if options_include_default?(options)
          change_column_default(table_name, column_name, options[:default])
        end
        if options.key?(:null)
          change_column_null(table_name, column_name, options[:null], options[:default])
        end
      end
1046

1047 1048
      # Changes the default value of a table column; +default+ is rendered
      # through +quote+.
      def change_column_default(table_name, column_name, default)
        clear_cache!
        quoted_table  = quote_table_name(table_name)
        quoted_column = quote_column_name(column_name)
        execute "ALTER TABLE #{quoted_table} ALTER COLUMN #{quoted_column} SET DEFAULT #{quote(default)}"
      end
1052

1053
      # Adds (+null+ falsy) or drops (+null+ truthy) the NOT NULL constraint
      # on a column. When adding it with a non-nil +default+, existing NULL
      # values are first backfilled with that default.
      def change_column_null(table_name, column_name, null, default = nil)
        clear_cache!
        quoted_table  = quote_table_name(table_name)
        quoted_column = quote_column_name(column_name)
        if !null && !default.nil?
          execute("UPDATE #{quoted_table} SET #{quoted_column}=#{quote(default)} WHERE #{quoted_column} IS NULL")
        end
        execute("ALTER TABLE #{quoted_table} ALTER #{quoted_column} #{null ? 'DROP' : 'SET'} NOT NULL")
      end

1061 1062
      # Renames a column in a table, clearing the adapter cache first.
      def rename_column(table_name, column_name, new_column_name)
        clear_cache!
        quoted_table = quote_table_name(table_name)
        old_column   = quote_column_name(column_name)
        new_column   = quote_column_name(new_column_name)
        execute "ALTER TABLE #{quoted_table} RENAME COLUMN #{old_column} TO #{new_column}"
      end
1066

1067 1068 1069 1070
      # Drops +index_name+. +table_name+ is not needed by the statement.
      def remove_index!(table_name, index_name) #:nodoc:
        quoted_index = quote_table_name(index_name)
        execute "DROP INDEX #{quoted_index}"
      end

1071 1072 1073 1074
      # Renames an index.
      #
      # +old_name+ may be schema-qualified, so it is quoted as a table-style
      # reference (as remove_index! does). RENAME TO takes only a bare
      # identifier, so +new_name+ is quoted as a single name. For undotted
      # names both quoting functions produce the same output.
      def rename_index(table_name, old_name, new_name)
        execute "ALTER INDEX #{quote_table_name(old_name)} RENAME TO #{quote_column_name(new_name)}"
      end

1075 1076
      # Maximum length of a generated index name. 63 presumably mirrors
      # PostgreSQL's default identifier length limit (NAMEDATALEN - 1).
      def index_name_length
        63
      end
1078

1079 1080
      # Maps logical Rails types to PostgreSQL-specific data types.
      #
      # Only integer is special-cased: its +limit+ is a byte size that selects
      # smallint, integer or bigint. Every other type is delegated to the
      # superclass. Raises ActiveRecordError for an unsupported byte size.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        return super unless type.to_s == 'integer'
        # Without a limit, plain integer is used.
        return 'integer' unless limit

        case limit
        when 1, 2 then 'smallint'
        when 3, 4 then 'integer'
        when 5..8 then 'bigint'
        else
          raise ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead."
        end
      end
1091

1092
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column. Each ORDER BY
      # expression is therefore appended to the select list under an alias.
      #
      #   distinct("posts.id", ["posts.created_at desc"])
      #   # => "DISTINCT posts.id, posts.created_at AS alias_0"
      def distinct(columns, orders) #:nodoc:
        return "DISTINCT #{columns}" if orders.empty?

        # Strip ASC/DESC together with any trailing NULLS FIRST/LAST; otherwise
        # "NULLS LAST" would be glued onto the column expression.
        order_columns = orders.map { |s| s.gsub(/\s+(?:ASC|DESC)\s*(?:NULLS\s+(?:FIRST|LAST)\s*)?/i, '') }
        order_columns = order_columns.reject { |c| c.blank? }

        # Nothing left to add: avoid emitting a dangling ", ".
        return "DISTINCT #{columns}" if order_columns.empty?

        aliased = order_columns.each_with_index.map { |column, i| "#{column} AS alias_#{i}" }
        "DISTINCT #{columns}, #{aliased.join(', ')}"
      end
1109

1110
      module Utils
        extend self

        # Splits a table reference into <tt>[schema_name, table_name]</tt>.
        # +schema_name+ is nil when +name+ carries no schema part. Surrounding
        # double quotes are removed from both parts, whether or not they were
        # present in +name+. Accepted forms include:
        #
        # * <tt>table_name</tt>
        # * <tt>"table.name"</tt>
        # * <tt>schema_name.table_name</tt>
        # * <tt>schema_name."table.name"</tt>
        # * <tt>"schema.name"."table name"</tt>
        def extract_schema_and_table(name)
          # Tokens are either unquoted runs or whole "quoted" identifiers; only
          # the first two matter.
          parts = name.scan(/[^".\s]+|"[^"]*"/).first(2).map { |part| part.gsub(/(^"|"$)/,'') }
          parts.length == 2 ? parts : [nil, parts.first]
        end
      end

1129
      protected
1130
        # Returns the version of the connected PostgreSQL server as reported by
        # the connection's server_version (an integer; #connect compares it
        # against 80300).
        def postgresql_version
          @connection.server_version
        end

1135 1136 1137
        # Maps PostgreSQL error messages onto Active Record exceptions: unique
        # violations become RecordNotUnique, foreign key violations become
        # InvalidForeignKey, and anything else is left to the superclass.
        def translate_exception(exception, message)
          error_text = exception.message
          if /duplicate key value violates unique constraint/ =~ error_text
            RecordNotUnique.new(message, exception)
          elsif /violates foreign key constraint/ =~ error_text
            InvalidForeignKey.new(message, exception)
          else
            super
          end
        end

D
Initial  
David Heinemeier Hansson 已提交
1146
      private
1147 1148
        # SQLSTATE "feature not supported". PostgreSQL also reports it for a
        # prepared statement whose result type may have changed, which is how
        # exec_cache recognises a stale cached statement.
        FEATURE_NOT_SUPPORTED = "0A000" # :nodoc:

        # Sends +sql+ as a plain, unprepared query. +binds+ is accepted only so
        # the signature matches exec_cache; it is not used.
        def exec_no_cache(sql, binds)
          @connection.async_exec sql
        end
1152

1153
        # Executes +sql+ as a prepared statement, preparing it first when this
        # connection has not cached it yet (see prepare_statement). The +binds+
        # values are type cast and sent as the statement parameters.
        #
        # If PostgreSQL rejects the statement with FEATURE_NOT_SUPPORTED, the
        # cache entry is dropped and the statement is re-prepared and run once
        # more. The retry is bounded: an error that persists (for example one
        # raised while preparing the statement itself) is re-raised instead of
        # looping forever.
        def exec_cache(sql, binds)
          retried = false
          begin
            stmt_key = prepare_statement sql

            # Clear the queue
            @connection.get_last_result
            @connection.send_query_prepared(stmt_key, binds.map { |col, val|
              type_cast(val, col)
            })
            @connection.block
            @connection.get_last_result
          rescue PGError => e
            # Get the PG code for the failure.  Annoyingly, the code for
            # prepared statements whose return value may have changed is
            # FEATURE_NOT_SUPPORTED.  Check here for more details:
            # http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/backend/utils/cache/plancache.c#l573
            # e.result may be nil (presumably for errors not tied to a query
            # result); then there is no SQLSTATE to inspect.
            code = e.result && e.result.result_error_field(PGresult::PG_DIAG_SQLSTATE)
            if FEATURE_NOT_SUPPORTED == code && !retried
              retried = true
              @statements.delete sql_key(sql)
              retry
            else
              raise e
            end
          end
        end

        # Builds the key under which the prepared statement for +sql+ is cached
        # on the client. The current schema search path is part of the key.
        def sql_key(sql)
          search_path = schema_search_path
          "#{search_path}-#{sql}"
        end

        # Returns the server-side statement name for +sql+, preparing the
        # statement under a fresh name from the pool on first use.
        def prepare_statement(sql)
          cache_key = sql_key(sql)
          return @statements[cache_key] if @statements.key?(cache_key)

          statement_name = @statements.next_key
          @connection.prepare statement_name, sql
          @statements[cache_key] = statement_name
          @statements[cache_key]
        end
1196

P
Pratik Naik 已提交
1197
        # Type OID PostgreSQL uses for the money data type.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:
        # Type OID PostgreSQL uses for the bytea (binary string) data type.
        BYTEA_COLUMN_TYPE_OID = 17 #:nodoc:
1201 1202 1203 1204 1205 1206 1207 1208 1209

        # Opens the PGconn from the stored connection parameters, records the
        # money precision appropriate for the server version, and then applies
        # the per-connection settings via configure_connection.
        def connect
          @connection = PGconn.connect(*@connection_parameters)

          # Money type has a fixed precision of 10 in PostgreSQL 8.2 and below, and as of
          # PostgreSQL 8.3 it has a fixed precision of 19. PostgreSQLColumn.extract_precision
          # should know about this but can't detect it there, so deal with it here.
          PostgreSQLColumn.money_precision =
            if postgresql_version >= 80300
              19
            else
              10
            end

          configure_connection
        end

1215
        # Configures the encoding, verbosity, schema search path, and time zone of the connection.
        # This is called by #connect and should not be called manually.
        def configure_connection
          client_encoding = @config[:encoding]
          @connection.set_client_encoding(client_encoding) if client_encoding

          min_messages = @config[:min_messages]
          self.client_min_messages = min_messages if min_messages
          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]

          # Use standard-conforming strings if available so we don't have to do the E'...' dance.
          set_standard_conforming_strings

          # If using Active Record's time zone support configure the connection to return
          # TIMESTAMP WITH ZONE types in UTC. Otherwise fall back to @local_tz when set.
          time_zone = ActiveRecord::Base.default_timezone == :utc ? 'UTC' : @local_tz
          execute("SET time zone '#{time_zone}'", 'SCHEMA') if time_zone
        end

1236
        # Returns currval(+sequence_name+) for this session, converted with
        # Integer() (so a non-numeric value raises).
        def last_insert_id(sequence_name) #:nodoc:
          result = exec_query("SELECT currval($1)", 'SQL', [[nil, sequence_name]])
          current_value = result.rows.first.first
          Integer(current_value)
        end

1242
        # Executes a SELECT query and returns the results, performing any data type
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
        # The result of exec_query is converted with to_a.
        def select(sql, name = nil, binds = [])
          result = exec_query(sql, name, binds)
          result.to_a
        end

        # Runs +sql+ and returns <tt>[field_names, rows]</tt>; the driver result
        # is cleared once both have been read.
        def select_raw(sql, name = nil)
          res = execute(sql, name)
          rows = result_as_array(res)
          field_names = res.fields
          res.clear
          [field_names, rows]
        end

1256
        # Returns the list of a table's column names, data types, default values,
        # and NOT NULL flags, one row per column.
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value, column.not_null
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - ::regclass is a cast that resolves a table name to its OID
        #  - pg_get_expr(d.adbin, d.adrelid) decompiles the stored default
        #    expression. The pg_attrdef.adsrc column it replaces does not track
        #    later changes and was removed in PostgreSQL 12.
        def column_definitions(table_name) #:nodoc:
          exec_query(<<-end_sql, 'SCHEMA').rows
            SELECT a.attname, format_type(a.atttypid, a.atttypmod),
                   pg_get_expr(d.adbin, d.adrelid), a.attnotnull
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{quote_table_name(table_name)}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
        end
1284 1285

        # Splits off the leading identifier of +name+. Returns
        # <tt>[identifier, rest]</tt>, where +rest+ is what follows the identifier
        # minus one leading "." (nil when nothing remains), or nil if no
        # identifier matches. A leading double-quoted identifier is unquoted;
        # otherwise the identifier is the first run of non-dot characters.
        def extract_pg_identifier_from_name(name)
          pattern = name.start_with?('"') ? /\"([^\"]+)\"/ : /([^\.]+)/
          match_data = name.match(pattern)
          return unless match_data

          remainder = name[match_data[0].length..-1]
          remainder = remainder[1..-1] if remainder.start_with?(".")
          [match_data[1], remainder.empty? ? nil : remainder]
        end
1294

1295 1296 1297 1298 1299
        # Returns the table reference of an INSERT ... VALUES (...) statement
        # (the text between INTO and the column list / VALUES, stripped), or
        # nil when +sql+ does not have that shape.
        def extract_table_ref_from_insert_sql(sql)
          match = sql.match(/into\s+([^\(]*).*values\s*\(/i)
          match[1].strip if match
        end

1300 1301 1302
        # Builds a new TableDefinition bound to this adapter.
        def table_definition
          TableDefinition.new self
        end
D
Initial  
David Heinemeier Hansson 已提交
1303 1304 1305
    end
  end
end