postgresql_adapter.rb 43.9 KB
Newer Older
D
Initial  
David Heinemeier Hansson 已提交
1
require 'active_record/connection_adapters/abstract_adapter'
2
require 'active_support/core_ext/object/blank'
3
require 'active_record/connection_adapters/statement_pool'
4 5 6

# Make sure we're using pg high enough for PGResult#values
gem 'pg', '~> 0.11'
7
require 'pg'
D
Initial  
David Heinemeier Hansson 已提交
8 9

module ActiveRecord
J
Jon Leighton 已提交
10
  module Core::ClassMethods
D
Initial  
David Heinemeier Hansson 已提交
11
    # Establishes a connection to the database that's used by all Active Record objects
J
Jon Leighton 已提交
12
    def postgresql_connection(config) # :nodoc:
13
      config = config.symbolize_keys
D
Initial  
David Heinemeier Hansson 已提交
14
      host     = config[:host]
15
      port     = config[:port] || 5432
16 17
      username = config[:username].to_s if config[:username]
      password = config[:password].to_s if config[:password]
D
Initial  
David Heinemeier Hansson 已提交
18

19
      if config.key?(:database)
D
Initial  
David Heinemeier Hansson 已提交
20 21 22 23 24
        database = config[:database]
      else
        raise ArgumentError, "No database specified. Missing argument: database."
      end

25
      # The postgres drivers don't allow the creation of an unconnected PGconn object,
26 27 28 29
      # so just pass a nil connection object for the time being.
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, [host, port, nil, nil, database, username, password], config)
    end
  end
30

31 32 33 34 35 36 37
  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      def initialize(name, default, sql_type = nil, null = true)
        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end
38

39 40 41
      # :stopdoc:
      class << self
        attr_accessor :money_precision
42 43 44 45 46 47 48 49 50 51
        # Converts a PostgreSQL timestamp string to a Ruby value.
        #
        # Non-String input is returned untouched. The special PostgreSQL
        # values 'infinity' and '-infinity' become positive and negative
        # Float infinity; every other string is delegated to the superclass.
        def string_to_time(string)
          return string unless string.is_a?(String)

          if string == 'infinity'
            Float::INFINITY
          elsif string == '-infinity'
            -Float::INFINITY
          else
            super
          end
        end
52

53 54
        def cast_hstore(object)
          if Hash === object
A
Aaron Patterson 已提交
55 56 57
            object.map { |k,v|
              "#{escape_hstore(k)}=>#{escape_hstore(v)}"
            }.join ', '
58
          else
A
Aaron Patterson 已提交
59 60
            kvs = object.scan(/(?<!\\)".*?(?<!\\)"/).map { |o|
              unescape_hstore(o[1...-1])
61
            }
A
Aaron Patterson 已提交
62 63 64 65 66
            Hash[kvs.each_slice(2).to_a]
          end
        end

        private
67 68 69 70 71 72 73 74 75 76
        # Characters that are backslash-escaped inside an hstore key or value,
        # mapped to their escaped form. Frozen so the shared lookup tables
        # cannot be mutated by accident at runtime.
        HSTORE_ESCAPE = {
          ' '  => '\\ ',
          '\\' => '\\\\',
          '"'  => '\\"',
          '='  => '\\=',
        }.freeze
        # Matches any single character that needs escaping.
        HSTORE_ESCAPE_RE   = Regexp.union(HSTORE_ESCAPE.keys)
        # Reverse mapping: escaped sequence => original character.
        HSTORE_UNESCAPE    = HSTORE_ESCAPE.invert.freeze
        # Matches any escaped sequence produced by HSTORE_ESCAPE.
        HSTORE_UNESCAPE_RE = Regexp.union(HSTORE_UNESCAPE.keys)

A
Aaron Patterson 已提交
77
        def unescape_hstore(value)
78 79
          value.gsub(HSTORE_UNESCAPE_RE) do |match|
            HSTORE_UNESCAPE[match]
A
Aaron Patterson 已提交
80 81 82 83
          end
        end

        def escape_hstore(value)
84 85
          value.gsub(HSTORE_ESCAPE_RE) do |match|
            HSTORE_ESCAPE[match]
86
          end
87
        end
88 89 90
      end
      # :startdoc:

91
      private
92
        def extract_limit(sql_type)
93 94 95 96 97
          case sql_type
          when /^bigint/i;    8
          when /^smallint/i;  2
          else super
          end
98 99
        end

100 101 102 103 104
        # Extracts the scale from PostgreSQL-specific data types.
        def extract_scale(sql_type)
          # The money type always carries exactly two fractional digits;
          # everything else is resolved by the superclass.
          return 2 if sql_type =~ /^money/

          super
        end
105

106 107
        # Extracts the precision from PostgreSQL-specific data types.
        def extract_precision(sql_type)
108 109 110 111 112
          if sql_type == 'money'
            self.class.money_precision
          else
            super
          end
113
        end
114

115 116 117
        # Maps PostgreSQL-specific data types to logical Rails types.
        def simplified_type(field_type)
          case field_type
118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
          # Numeric and monetary types
          when /^(?:real|double precision)$/
            :float
          # Monetary types
          when 'money'
            :decimal
          when 'hstore'
            :hstore
          # Character types
          when /^(?:character varying|bpchar)(?:\(\d+\))?$/
            :string
          # Binary data types
          when 'bytea'
            :binary
          # Date/time types
          when /^timestamp with(?:out)? time zone$/
            :datetime
          when 'interval'
            :string
          # Geometric types
          when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
            :string
          # Network address types
          when /^(?:cidr|inet|macaddr)$/
            :string
          # Bit strings
          when /^bit(?: varying)?(?:\(\d+\))?$/
            :string
          # XML type
          when 'xml'
            :xml
          # tsvector type
          when 'tsvector'
            :tsvector
          # Arrays
          when /^\D+\[\]$/
            :string
          # Object identifier types
          when 'oid'
            :integer
          # UUID type
          when 'uuid'
            :string
          # Small and big integer types
          when /^(?:small|big)int$/
            :integer
          # Pass through all types that are not specific to PostgreSQL.
          else
            super
167 168
          end
        end
169

170 171 172
        # Extracts the value from a PostgreSQL column default definition.
        def self.extract_value_from_default(default)
          # +default+ is the default expression string reported by PostgreSQL
          # for a column (or nil). Returns the literal value as a String,
          # +true+/+false+ for boolean defaults, or nil when the default is
          # absent or is an expression whose value cannot be known statically.
          case default
          # This is a performance optimization for Ruby 1.9.2 in development.
          # If the value is nil, we return nil straight away without checking
          # the regular expressions. If we check each regular expression,
          # Regexp#=== will call NilClass#to_str, which will trigger
          # method_missing (defined by whiny nil in ActiveSupport) which
          # makes this method very very slow.
          when NilClass
            nil
          # Numeric types. The optional surrounding parentheses (emitted for
          # negative defaults such as "(-1)") are kept outside the capture
          # group so they never leak into the returned value.
          when /\A\(?(-?\d+(\.\d*)?)\)?\z/
            $1
          # Character types
          when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
            $1
          # Character types (8.1 formatting)
          when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
            # Decode three-digit octal escapes (e.g. \134) into characters.
            $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
          # Binary data types
          when /\A'(.*)'::bytea\z/m
            $1
          # Date/time types
          when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
            $1
          when /\A'(.*)'::interval\z/
            $1
          # Boolean type
          when 'true'
            true
          when 'false'
            false
          # Geometric types
          when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
            $1
          # Network address types
          when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
            $1
          # Bit string types
          when /\AB'(.*)'::"?bit(?: varying)?"?\z/
            $1
          # XML type
          when /\A'(.*)'::xml\z/m
            $1
          # Arrays
          when /\A'(.*)'::"?\D+"?\[\]\z/
            $1
          # Object identifier types. Plain integers are already handled by the
          # numeric branch above; the capture group keeps this branch correct
          # should the ordering ever change.
          when /\A(-?\d+)\z/
            $1
          else
            # Anything else is blank, some user type, or some function
            # and we can't know the value of that, so return nil.
            nil
          end
        end
D
Initial  
David Heinemeier Hansson 已提交
227 228
    end

229 230
    # The PostgreSQL adapter works both with the native C (http://ruby.scripting.ca/postgres/) and the pure
    # Ruby (available both as gem and from http://rubyforge.org/frs/?group_id=234&release_id=1944) drivers.
231 232 233
    #
    # Options:
    #
P
Pratik Naik 已提交
234 235 236 237 238
    # * <tt>:host</tt> - Defaults to "localhost".
    # * <tt>:port</tt> - Defaults to 5432.
    # * <tt>:username</tt> - Defaults to nothing.
    # * <tt>:password</tt> - Defaults to nothing.
    # * <tt>:database</tt> - The name of the database. No default, must be provided.
239
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection given
240
    #   as a string of comma-separated schema names. This is backward-compatible with the <tt>:schema_order</tt> option.
241
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO
242
    #   <encoding></tt> call on the connection.
243
    # * <tt>:min_messages</tt> - An optional client min messages that is used in a
244
    #   <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
245
    class PostgreSQLAdapter < AbstractAdapter
246 247 248 249 250
      class TableDefinition < ActiveRecord::ConnectionAdapters::TableDefinition
        # Adds a column of PostgreSQL type +xml+. The first positional
        # argument is the column name; a trailing hash is taken as the
        # column options.
        def xml(*args)
          column_options = args.extract_options!
          column_name    = args.first
          column(column_name, 'xml', column_options)
        end
251 252 253 254 255

        # Adds a column of PostgreSQL type +tsvector+. The first positional
        # argument is the column name; a trailing hash is taken as the
        # column options.
        def tsvector(*args)
          column_options = args.extract_options!
          column_name    = args.first
          column(column_name, 'tsvector', column_options)
        end
256 257 258 259

        def hstore(name, options = {})
          column(name, 'hstore', options)
        end
260 261
      end

262
      ADAPTER_NAME = 'PostgreSQL'
263 264

      NATIVE_DATABASE_TYPES = {
265
        :primary_key => "serial primary key",
266 267 268 269 270 271 272 273 274 275
        :string      => { :name => "character varying", :limit => 255 },
        :text        => { :name => "text" },
        :integer     => { :name => "integer" },
        :float       => { :name => "float" },
        :decimal     => { :name => "decimal" },
        :datetime    => { :name => "timestamp" },
        :timestamp   => { :name => "timestamp" },
        :time        => { :name => "time" },
        :date        => { :name => "date" },
        :binary      => { :name => "bytea" },
276
        :boolean     => { :name => "boolean" },
277 278
        :xml         => { :name => "xml" },
        :tsvector    => { :name => "tsvector" }
279 280
      }

281
      # Returns 'PostgreSQL' as adapter name for identification purposes.
282
      def adapter_name
283
        ADAPTER_NAME
284 285
      end

286 287
      # Returns +true+, since this connection adapter supports prepared statement
      # caching.
288 289 290 291
      def supports_statement_cache?
        true
      end

292 293 294 295
      # Returns +true+: this adapter reports support for a per-column sort
      # order on indexes.
      def supports_index_sort_order?
        true
      end

296 297 298 299
      class StatementPool < ConnectionAdapters::StatementPool
        def initialize(connection, max)
          super
          @counter = 0
300
          @cache   = Hash.new { |h,pid| h[pid] = {} }
301 302
        end

303 304 305 306
        # Read-only collection API, delegated to the current process's cache.

        # Yields each entry of the cache to the given block.
        def each(&block)
          cache.each(&block)
        end

        # Returns true when +key+ is present in the cache.
        def key?(key)
          cache.key?(key)
        end

        # Looks up the cached value stored under +key+.
        def [](key)
          cache[key]
        end

        # Number of entries currently cached.
        def length
          cache.length
        end
307 308 309 310 311 312

        # Returns the name the next prepared statement would receive ("a"
        # followed by counter + 1). Does not advance the counter itself.
        def next_key
          upcoming = @counter + 1
          "a#{upcoming}"
        end

        def []=(sql, key)
313 314
          while @max <= cache.size
            dealloc(cache.shift.last)
315 316
          end
          @counter += 1
317
          cache[sql] = key
318 319 320
        end

        def clear
321
          cache.each_value do |stmt_key|
322 323
            dealloc stmt_key
          end
324
          cache.clear
325 326
        end

327 328 329 330 331
        # Deallocates the statement cached under +sql_key+ and removes the
        # entry from the cache. Returns the removed cache value.
        def delete(sql_key)
          statement_name = cache[sql_key]
          dealloc statement_name
          cache.delete sql_key
        end

332
        private
333 334 335 336
        # Returns the cache hash belonging to the current process, looked up
        # in @cache by process id.
        def cache
          @cache[Process.pid]
        end

337
        def dealloc(key)
338 339 340 341 342 343 344
          @connection.query "DEALLOCATE #{key}" if connection_active?
        end

        def connection_active?
          @connection.status == PGconn::CONNECTION_OK
        rescue PGError
          false
345 346 347
        end
      end

348 349
      # Initializes and connects a PostgreSQL adapter.
      def initialize(connection, logger, connection_parameters, config)
350
        super(connection, logger)
351
        @connection_parameters, @config = connection_parameters, config
352
        @visitor = Arel::Visitors::PostgreSQL.new self
353

354 355
        # @local_tz is initialized as nil to avoid warnings when connect tries to use it
        @local_tz = nil
356 357
        @table_alias_length = nil

358
        connect
359 360
        @statements = StatementPool.new @connection,
                                        config.fetch(:statement_limit) { 1000 }
361 362 363 364 365

        if postgresql_version < 80200
          raise "Your version of PostgreSQL (#{postgresql_version}) is too old, please upgrade!"
        end

366
        @local_tz = execute('SHOW TIME ZONE', 'SCHEMA').first["TimeZone"]
367 368
      end

X
Xavier Noria 已提交
369
      # Clears the prepared statements cache.
370 371 372 373
      def clear_cache!
        @statements.clear
      end

374 375
      # Is this connection alive and ready for queries?
      def active?
376 377
        @connection.status == PGconn::CONNECTION_OK
      rescue PGError
378
        false
379 380 381 382
      end

      # Close then reopen the connection.
      def reconnect!
383 384 385
        clear_cache!
        @connection.reset
        configure_connection
386
      end
387

388 389 390 391 392
      def reset!
        clear_cache!
        super
      end

393 394
      # Disconnects from the database if already connected. Otherwise, this
      # method does nothing.
395
      def disconnect!
396
        clear_cache!
397 398
        @connection.close rescue nil
      end
399

400
      def native_database_types #:nodoc:
401
        NATIVE_DATABASE_TYPES
402
      end
403

404
      # Returns true, since this connection adapter supports migrations.
405 406
      def supports_migrations?
        true
407 408
      end

409
      # Does PostgreSQL support finding primary key on non-Active Record tables?
410 411 412 413
      def supports_primary_key? #:nodoc:
        true
      end

414 415 416
      # Enable standard-conforming strings if available.
      def set_standard_conforming_strings
        old, self.client_min_messages = client_min_messages, 'panic'
417
        execute('SET standard_conforming_strings = on', 'SCHEMA') rescue nil
418 419
      ensure
        self.client_min_messages = old
420 421
      end

422
      def supports_insert_with_returning?
423
        true
424 425
      end

426 427 428
      # Returns +true+: this adapter reports that DDL statements can be run
      # inside a transaction.
      def supports_ddl_transactions?
        true
      end
429

430
      # Returns true, since this connection adapter supports savepoints.
431 432 433
      def supports_savepoints?
        true
      end
434

435 436 437 438 439
      # Returns true.
      def supports_explain?
        true
      end

440
      # Returns the maximum identifier length configured for the PostgreSQL server.
441
      def table_alias_length
442
        @table_alias_length ||= query('SHOW max_identifier_length')[0][0].to_i
443
      end
444

445 446
      # QUOTING ==================================================

447
      # Escapes binary strings for bytea input to the database.
448 449
      def escape_bytea(value)
        @connection.escape_bytea(value) if value
450 451 452 453 454
      end

      # Unescapes bytea output from a database to the binary string it represents.
      # NOTE: This is NOT an inverse of escape_bytea! This is only to be used
      #       on escaped binary output from database drive.
455 456
      def unescape_bytea(value)
        @connection.unescape_bytea(value) if value
457 458
      end

459 460
      # Quotes PostgreSQL-specific data types for SQL input.
      def quote(value, column = nil) #:nodoc:
461 462
        return super unless column

A
Aaron Patterson 已提交
463
        case value
464 465 466
        when Float
          return super unless value.infinite? && column.type == :datetime
          "'#{value.to_s.downcase}'"
A
Aaron Patterson 已提交
467 468
        when Numeric
          return super unless column.sql_type == 'money'
469
          # Not truly string input, so doesn't require (or allow) escape string syntax.
470
          "'#{value}'"
A
Aaron Patterson 已提交
471 472 473 474 475 476 477 478 479 480 481
        when String
          case column.sql_type
          when 'bytea' then "'#{escape_bytea(value)}'"
          when 'xml'   then "xml '#{quote_string(value)}'"
          when /^bit/
            case value
            when /^[01]*$/      then "B'#{value}'" # Bit-string notation
            when /^[0-9A-F]*$/i then "X'#{value}'" # Hexadecimal notation
            end
          else
            super
482
          end
483 484 485 486 487
        else
          super
        end
      end

488 489 490 491 492 493
      def type_cast(value, column)
        return super unless column

        case value
        when String
          return super unless 'bytea' == column.sql_type
494
          { :value => value, :format => 1 }
495 496 497 498 499
        else
          super
        end
      end

500 501 502
      # Quotes strings for use in SQL input.
      def quote_string(s) #:nodoc:
        @connection.escape(s)
503 504
      end

505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523
      # Checks the following cases:
      #
      # - table_name
      # - "table.name"
      # - schema_name.table_name
      # - schema_name."table.name"
      # - "schema.name".table_name
      # - "schema.name"."table.name"
      def quote_table_name(name)
        # Split off the first identifier; +remainder+ is whatever follows
        # it, or nil when the name has no schema qualifier.
        schema, remainder = extract_pg_identifier_from_name(name.to_s)
        return quote_column_name(schema) unless remainder

        # Only the first identifier of the remainder is the table name;
        # anything after it is ignored.
        table_name, _ = extract_pg_identifier_from_name(remainder)
        "#{quote_column_name(schema)}.#{quote_column_name(table_name)}"
      end

524 525
      # Quotes column names for use in SQL queries.
      def quote_column_name(name) #:nodoc:
526
        PGconn.quote_ident(name.to_s)
527 528
      end

529 530 531
      # Quote date/time values for use in SQL input. Includes microseconds
      # if the value is a Time responding to usec.
      def quoted_date(value) #:nodoc:
532 533 534 535 536
        if value.acts_like?(:time) && value.respond_to?(:usec)
          "#{super}.#{sprintf("%06d", value.usec)}"
        else
          super
        end
537 538
      end

539 540
      # Set the authorized user for this session
      def session_auth=(user)
541
        clear_cache!
A
Aaron Patterson 已提交
542
        exec_query "SET SESSION AUTHORIZATION #{user}"
543 544
      end

545 546
      # REFERENTIAL INTEGRITY ====================================

547
      def supports_disable_referential_integrity? #:nodoc:
548
        true
549 550
      end

551
      def disable_referential_integrity #:nodoc:
552
        if supports_disable_referential_integrity? then
553 554
          execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} DISABLE TRIGGER ALL" }.join(";"))
        end
555 556
        yield
      ensure
557
        if supports_disable_referential_integrity? then
558 559
          execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} ENABLE TRIGGER ALL" }.join(";"))
        end
560
      end
561 562 563

      # DATABASE STATEMENTS ======================================

564
      def explain(arel, binds = [])
X
Xavier Noria 已提交
565
        sql = "EXPLAIN #{to_sql(arel)}"
566
        ExplainPrettyPrinter.new.pp(exec_query(sql, 'EXPLAIN', binds))
X
Xavier Noria 已提交
567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605
      end

      class ExplainPrettyPrinter # :nodoc:
        # Formats the result of an EXPLAIN so that it resembles the output of
        # the PostgreSQL shell:
        #
        #                                     QUERY PLAN
        #   ------------------------------------------------------------------------------
        #    Nested Loop Left Join  (cost=0.00..37.24 rows=8 width=0)
        #      Join Filter: (posts.user_id = users.id)
        #      ->  Index Scan using users_pkey on users  (cost=0.00..8.27 rows=1 width=4)
        #            Index Cond: (id = 1)
        #      ->  Seq Scan on posts  (cost=0.00..28.88 rows=8 width=4)
        #            Filter: (posts.user_id = 1)
        #   (6 rows)
        #
        # +result+ must respond to +columns+ and +rows+; the first column name
        # is the header and the first cell of each row is one plan line.
        # Returns the formatted text, terminated by a newline.
        def pp(result)
          title      = result.columns.first
          plan_lines = result.rows.map { |row| row.first }

          # The ruler is two chars wider than the longest line: one char of
          # padding on each side (see the extra hyphens in the example).
          width = ([title] + plan_lines).map { |text| text.length }.max + 2

          output = [title.center(width).rstrip, '-' * width]
          plan_lines.each { |plan_line| output << " #{plan_line}" }

          row_count = result.rows.length
          noun      = row_count == 1 ? 'row' : 'rows'
          output << "(#{row_count} #{noun})"

          output.join("\n") + "\n"
        end
      end

606 607 608 609 610 611
      # Executes a SELECT query and returns an array of rows. Each row is an
      # array of field values.
      def select_rows(sql, name = nil)
        # The rows are the last element of what select_raw returns.
        raw_result = select_raw(sql, name)
        raw_result.last
      end

612
      # Executes an INSERT query and returns the new record's ID
613
      def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
614 615 616 617 618
        unless pk
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          pk = primary_key(table_ref) if table_ref
        end
619

620
        if pk
621 622 623
          select_value("#{sql} RETURNING #{quote_column_name(pk)}")
        else
          super
624
        end
625
      end
626
      alias :create :insert
627

628 629
      # create a 2D array representing the result set
      def result_as_array(res) #:nodoc:
        # Builds a 2D array (rows of field values) from +res+, unescaping
        # bytea fields and normalising money fields in place. Returns the
        # rows array.

        # Pair every field index with its type OID so we can tell whether any
        # column needs post-processing.
        ftypes = Array.new(res.nfields) do |i|
          [i, res.ftype(i)]
        end

        rows = res.values
        return rows unless ftypes.any? { |_, x|
          x == BYTEA_COLUMN_TYPE_OID || x == MONEY_COLUMN_TYPE_OID
        }

        typehash = ftypes.group_by { |_, type| type }
        binaries = typehash[BYTEA_COLUMN_TYPE_OID] || []
        monies   = typehash[MONEY_COLUMN_TYPE_OID] || []

        rows.each do |row|
          # unescape string passed BYTEA field (OID == 17)
          binaries.each do |index, _|
            row[index] = unescape_bytea(row[index])
          end

          # If this is a money type column and there are any currency symbols,
          # then strip them off. Indeed it would be prettier to do this in
          # PostgreSQLColumn.string_to_decimal but would break form input
          # fields that call value_before_type_cast.
          monies.each do |index, _|
            data = row[index]
            # Because money output is formatted according to the locale, there are two
            # cases to consider (note the decimal separators):
            #  (1) $12,345,678.12
            #  (2) $12.345.678,12
            case data
            when /^-?\D+[\d,]+\.\d{2}$/  # (1)
              data.gsub!(/[^-\d.]/, '')
            when /^-?\D+[\d.]+,\d{2}$/  # (2)
              # gsub! returns nil when nothing was removed (e.g. "-12,34"),
              # so the two in-place edits must not be chained.
              data.gsub!(/[^-\d,]/, '')
              data.sub!(/,/, '.')
            end
          end
        end
      end


      # Queries the database and returns the results in an Array-like object
672
      def query(sql, name = nil) #:nodoc:
673
        log(sql, name) do
674
          result_as_array @connection.async_exec(sql)
675
        end
676 677
      end

678
      # Executes an SQL statement, returning a PGresult object on success
679 680
      # or raising a PGError exception otherwise.
      def execute(sql, name = nil)
681
        log(sql, name) do
682
          @connection.async_exec(sql)
683
        end
684 685
      end

686 687
      def substitute_at(column, index)
        Arel.sql("$#{index + 1}")
688 689
      end

A
Aaron Patterson 已提交
690
      def exec_query(sql, name = 'SQL', binds = [])
691
        log(sql, name, binds) do
692 693
          result = binds.empty? ? exec_no_cache(sql, binds) :
                                  exec_cache(sql, binds)
694

695 696 697
          ret = ActiveRecord::Result.new(result.fields, result_as_array(result))
          result.clear
          return ret
698 699 700
        end
      end

701 702 703 704 705 706 707 708 709
      # Runs +sql+ inside the logging wrapper and returns the number of
      # tuples the command affected. Statements without binds go through
      # exec_no_cache, statements with binds through exec_cache. The result
      # object is cleared once the count has been read.
      def exec_delete(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          pg_result = if binds.empty?
                        exec_no_cache(sql, binds)
                      else
                        exec_cache(sql, binds)
                      end

          row_count = pg_result.cmd_tuples
          pg_result.clear
          row_count
        end
      end
710
      alias :exec_update :exec_delete
711

712 713
      def sql_for_insert(sql, pk, id_value, sequence_name, binds)
        unless pk
714 715 716
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          pk = primary_key(table_ref) if table_ref
717 718 719 720 721 722 723
        end

        sql = "#{sql} RETURNING #{quote_column_name(pk)}" if pk

        [sql, binds]
      end

724
      # Executes an UPDATE query and returns the number of affected tuples.
725
      def update_sql(sql, name = nil)
726
        super.cmd_tuples
727 728
      end

729 730
      # Begins a transaction.
      def begin_db_transaction
731 732 733
        execute "BEGIN"
      end

734 735
      # Commits a transaction.
      def commit_db_transaction
736 737
        execute "COMMIT"
      end
738

739 740
      # Aborts a transaction.
      def rollback_db_transaction
741 742
        execute "ROLLBACK"
      end
743

744 745
      def outside_transaction?
        @connection.transaction_status == PGconn::PQTRANS_IDLE
746
      end
747

J
Jonathan Viney 已提交
748 749 750 751 752 753 754 755
      def create_savepoint
        execute("SAVEPOINT #{current_savepoint_name}")
      end

      def rollback_to_savepoint
        execute("ROLLBACK TO SAVEPOINT #{current_savepoint_name}")
      end

756
      def release_savepoint
J
Jonathan Viney 已提交
757 758
        execute("RELEASE SAVEPOINT #{current_savepoint_name}")
      end
759

760 761
      # SCHEMA STATEMENTS ========================================

762 763 764
      # Drops the database specified on the +name+ attribute
      # and creates it again using the provided +options+.
      def recreate_database(name, options = {}) #:nodoc:
765
        drop_database(name)
766
        create_database(name, options)
767 768
      end

769
      # Create a new PostgreSQL database. Options include <tt>:owner</tt>, <tt>:template</tt>,
770 771
      # <tt>:encoding</tt>, <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
      # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
772 773 774 775 776 777 778 779 780 781
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        options = options.reverse_merge(:encoding => "utf8")

        option_string = options.symbolize_keys.sum do |key, value|
          case key
          when :owner
782
            " OWNER = \"#{value}\""
783
          when :template
784
            " TEMPLATE = \"#{value}\""
785 786 787
          when :encoding
            " ENCODING = '#{value}'"
          when :tablespace
788
            " TABLESPACE = \"#{value}\""
789 790 791 792 793 794 795
          when :connection_limit
            " CONNECTION LIMIT = #{value}"
          else
            ""
          end
        end

796
        execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
797 798
      end

799
      # Drops a PostgreSQL database.
800 801 802 803
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
804
        execute "DROP DATABASE IF EXISTS #{quote_table_name(name)}"
805 806
      end

807 808
      # Returns the list of all tables in the schema search path or a specified schema.
      def tables(name = nil)
809
        query(<<-SQL, 'SCHEMA').map { |row| row[0] }
810
          SELECT tablename
811 812 813 814 815
          FROM pg_tables
          WHERE schemaname = ANY (current_schemas(false))
        SQL
      end

816
      # Returns true if table exists.
817 818
      # If the schema is not specified as part of +name+ then it will only find tables within
      # the current schema search path (regardless of permissions to access tables in other schemas)
819
      def table_exists?(name)
820
        schema, table = Utils.extract_schema_and_table(name.to_s)
821
        return false unless table
822

823 824
        binds = [[nil, table]]
        binds << [nil, schema] if schema
825 826

        exec_query(<<-SQL, 'SCHEMA', binds).rows.first[0].to_i > 0
827
            SELECT COUNT(*)
A
Aaron Patterson 已提交
828 829 830 831 832
            FROM pg_class c
            LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind in ('v','r')
            AND c.relname = $1
            AND n.nspname = #{schema ? '$2' : 'ANY (current_schemas(false))'}
833 834 835
        SQL
      end

836 837 838 839 840 841 842 843
      # Returns true if schema exists.
      def schema_exists?(name)
        # Count the pg_namespace rows whose name equals the bound parameter.
        sql = <<-SQL
          SELECT COUNT(*)
          FROM pg_namespace
          WHERE nspname = $1
        SQL

        match_count = exec_query(sql, 'SCHEMA', [[nil, name]]).rows.first[0]
        match_count.to_i > 0
      end
844

845
      # Returns an array of indexes for the given table.
846
      # Looks up the non-primary indexes of +table_name+ in the current schemas
      # and builds one IndexDefinition per index. +name+ is the label handed to
      # +query+ for the index lookup. Indexes for which no column name could be
      # resolved are left out of the result.
      def indexes(table_name, name = nil)
        result = query(<<-SQL, name)
           SELECT distinct i.relname, d.indisunique, d.indkey, pg_get_indexdef(d.indexrelid), t.oid
           FROM pg_class t
           INNER JOIN pg_index d ON t.oid = d.indrelid
           INNER JOIN pg_class i ON d.indexrelid = i.oid
           WHERE i.relkind = 'i'
             AND d.indisprimary = 'f'
             AND t.relname = '#{table_name}'
             AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname = ANY (current_schemas(false)) )
          ORDER BY i.relname
        SQL

        result.map do |row|
          index_name, unique_flag, indkey_list, inddef, oid = row
          indkey = indkey_list.split(" ")

          columns = Hash[query(<<-SQL, "Columns for index #{index_name} on #{table_name}")]
          SELECT a.attnum, a.attname
          FROM pg_attribute a
          WHERE a.attrelid = #{oid}
          AND a.attnum IN (#{indkey.join(",")})
          SQL

          # Resolve names in indkey order so the index's column order is kept.
          column_names = columns.values_at(*indkey).compact
          next if column_names.empty?

          # add info on sort order for columns (only desc order is explicitly specified, asc is the default)
          desc_order_columns = inddef.scan(/(\w+) DESC/).flatten
          orders = Hash[desc_order_columns.map { |order_column| [order_column, :desc] }]

          IndexDefinition.new(table_name, index_name, unique_flag == 't', column_names, [], orders)
        end.compact
      end

884 885
      # Returns the list of all column definitions for a table.
      # Builds one PostgreSQLColumn for every row returned by
      # +column_definitions+ for +table_name+. A column is nullable when its
      # +notnull+ flag is 'f'. +name+ is accepted but not used here.
      def columns(table_name, name = nil)
        # Limit, precision, and scale are all handled by the superclass.
        column_definitions(table_name).map do |column_name, type, default, notnull|
          PostgreSQLColumn.new(column_name, default, type, notnull == 'f')
        end
      end

892 893 894 895 896
      # Returns the current database name.
      # Takes the first column of the first row returned by
      # <tt>select current_database()</tt>. The query is labelled 'SCHEMA',
      # matching the other introspection queries in this adapter.
      def current_database
        query('select current_database()', 'SCHEMA')[0][0]
      end

897 898 899 900 901
      # Returns the current schema name.
      # Runs <tt>SELECT current_schema</tt> (labelled 'SCHEMA') and takes the
      # first column of the first row.
      def current_schema
        rows = query('SELECT current_schema', 'SCHEMA')
        first_row = rows[0]
        first_row[0]
      end

902 903 904 905 906 907 908 909
      # Returns the current database encoding format.
      # Reads the encoding name of the connected database from
      # <tt>pg_database</tt>. The database is identified with an exact match
      # on <tt>current_database()</tt> inside the query, so characters such as
      # <tt>_</tt> or <tt>%</tt> in the database name are not treated as
      # pattern wildcards and no separate lookup of the name is needed.
      def encoding
        query(<<-end_sql, 'SCHEMA')[0][0]
          SELECT pg_encoding_to_char(pg_database.encoding) FROM pg_database
          WHERE pg_database.datname = current_database()
        end_sql
      end

910 911 912 913 914 915
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      #
      # This should not be called manually but set in database.yml.
      # Does nothing when +schema_csv+ is nil or false. Otherwise issues
      # <tt>SET search_path TO ...</tt> (labelled 'SCHEMA') and remembers the
      # value in <tt>@schema_search_path</tt>.
      def schema_search_path=(schema_csv)
        return unless schema_csv

        execute("SET search_path TO #{schema_csv}", 'SCHEMA')
        @schema_search_path = schema_csv
      end

922 923
      # Returns the active schema search path.
      # The value is memoized in <tt>@schema_search_path</tt>; the
      # <tt>SHOW search_path</tt> query only runs while that variable is unset.
      def schema_search_path
        @schema_search_path ||= query('SHOW search_path', 'SCHEMA')[0][0]
      end
926

927 928
      # Returns the current client message level.
      # Takes the first column of the first row returned by
      # <tt>SHOW client_min_messages</tt> (labelled 'SCHEMA').
      def client_min_messages
        query('SHOW client_min_messages', 'SCHEMA')[0][0]
      end

      # Set the client message level.
      # +level+ is interpolated between single quotes into
      # <tt>SET client_min_messages TO '...'</tt>; the statement is labelled 'SCHEMA'.
      def client_min_messages=(level)
        execute("SET client_min_messages TO '#{level}'", 'SCHEMA')
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      # Asks +serial_sequence+ for the sequence of +pk+ (default 'id') and
      # strips any schema prefix from it.
      #
      # Returns nil when +serial_sequence+ reports no sequence for the column,
      # and falls back to the conventional <tt>table_pk_seq</tt> name when the
      # lookup raises ActiveRecord::StatementInvalid.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        sequence = serial_sequence(table_name, pk || 'id')
        return nil unless sequence

        sequence.split('.').last
      rescue ActiveRecord::StatementInvalid
        "#{table_name}_#{pk || 'id'}_seq"
      end

      # Calls PostgreSQL's <tt>pg_get_serial_sequence</tt> with +table+ and
      # +column+ as bind parameters and returns the first column of the first
      # result row.
      def serial_sequence(table, column)
        result = exec_query(<<-eosql, 'SCHEMA', [[nil, table], [nil, column]])
          SELECT pg_get_serial_sequence($1, $2)
        eosql
        result.rows.first.first
      end

951 952
      # Resets the sequence of a table's primary key to the maximum value.
      # +pk+ and +sequence+ default to whatever +pk_and_sequence_for+ reports
      # for +table+. When a primary key is known but no sequence is, a warning
      # is logged (if a logger is set) and nothing else happens. Returns the
      # value of the +setval+ query, or nil when nothing was reset.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        unless pk && sequence
          default_pk, default_sequence = pk_and_sequence_for(table)

          pk ||= default_pk
          sequence ||= default_sequence
        end

        if @logger && pk && !sequence
          @logger.warn "#{table} has primary key #{pk} with no default sequence"
        end

        return unless pk && sequence

        quoted_sequence = quote_table_name(sequence)

        # The next value becomes MAX(pk) + the sequence increment, or the
        # sequence's min_value when the table has no rows.
        select_value <<-end_sql, 'Reset sequence'
            SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
        end_sql
      end

973 974
      # Returns a table's primary key and belonging sequence.
      # Returns <tt>[primary_key, sequence]</tt> for +table+, or nil when no
      # matching sequence row is found or the lookup raises. The sequence name
      # is prefixed with its schema unless that schema is 'public'.
      def pk_and_sequence_for(table) #:nodoc:
        # Look for a sequence with a dependency on the given table's
        # primary key.
        result = exec_query(<<-end_sql, 'SCHEMA').rows.first
          SELECT attr.attname, ns.nspname, seq.relname
          FROM pg_class seq
          INNER JOIN pg_depend dep ON seq.oid = dep.objid
          INNER JOIN pg_attribute attr ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          INNER JOIN pg_namespace ns ON seq.relnamespace = ns.oid
          WHERE seq.relkind  = 'S'
            AND cons.contype = 'p'
            AND dep.refobjid = '#{quote_table_name(table)}'::regclass
        end_sql

        # No row means no sequence depends on the table's primary key.
        return nil unless result

        pk, schema, sequence = result
        sequence = "#{schema}.#{sequence}" unless schema == 'public'

        # [primary_key, sequence]
        [pk, sequence]
      rescue
        # Best effort: any failure of the lookup is reported as "unknown".
        nil
      end

1001 1002
      # Returns just a table's primary key
      # +table+ is sent as the bind parameter <tt>$1</tt> (cast to regclass).
      # Returns the first column of the first result row, or nil when the
      # query yields no rows.
      def primary_key(table)
        row = exec_query(<<-end_sql, 'SCHEMA', [[nil, table]]).rows.first
          SELECT DISTINCT(attr.attname)
          FROM pg_attribute attr
          INNER JOIN pg_depend dep ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          WHERE cons.contype = 'p'
            AND dep.refobjid = $1::regclass
        end_sql

        row && row.first
      end

1015
      # Renames a table.
1016 1017 1018
      #
      # Example:
      #   rename_table('octopuses', 'octopi')
1019
      # Calls +clear_cache!+ and then issues <tt>ALTER TABLE ... RENAME TO ...</tt>
      # with both names passed through +quote_table_name+.
      def rename_table(name, new_name)
        clear_cache!
        execute "ALTER TABLE #{quote_table_name(name)} RENAME TO #{quote_table_name(new_name)}"
      end
1023

1024 1025
      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
S
Scott Barron 已提交
1026
      # Calls +clear_cache!+, builds an <tt>ALTER TABLE ... ADD COLUMN</tt>
      # statement whose type comes from +type_to_sql+ (using the :limit,
      # :precision and :scale options), lets +add_column_options!+ amend the
      # statement, and executes it.
      def add_column(table_name, column_name, type, options = {})
        clear_cache!
        add_column_sql = "ALTER TABLE #{quote_table_name(table_name)} ADD COLUMN #{quote_column_name(column_name)} #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}"
        add_column_options!(add_column_sql, options)

        execute add_column_sql
      end
D
Initial  
David Heinemeier Hansson 已提交
1033

1034 1035
      # Changes the column of a table.
      # Calls +clear_cache!+ and alters the column's type. Afterwards the
      # default is changed when +options_include_default?+ says so, and the
      # NULL constraint is changed when +options+ has a :null key.
      def change_column(table_name, column_name, type, options = {})
        clear_cache!
        quoted_table_name = quote_table_name(table_name)

        execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quote_column_name(column_name)} TYPE #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}"

        change_column_default(table_name, column_name, options[:default]) if options_include_default?(options)
        change_column_null(table_name, column_name, options[:null], options[:default]) if options.key?(:null)
      end
1044

1045 1046
      # Changes the default value of a table column.
      # Calls +clear_cache!+ and issues <tt>ALTER TABLE ... ALTER COLUMN ... SET DEFAULT</tt>
      # with +default+ passed through +quote+.
      def change_column_default(table_name, column_name, default)
        clear_cache!
        execute "ALTER TABLE #{quote_table_name(table_name)} ALTER COLUMN #{quote_column_name(column_name)} SET DEFAULT #{quote(default)}"
      end
1050

1051
      # Sets or drops the NOT NULL constraint of a column after calling
      # +clear_cache!+. When +null+ is false and a +default+ is given, rows
      # whose column IS NULL are first updated to that default.
      def change_column_null(table_name, column_name, null, default = nil)
        clear_cache!
        unless null || default.nil?
          execute("UPDATE #{quote_table_name(table_name)} SET #{quote_column_name(column_name)}=#{quote(default)} WHERE #{quote_column_name(column_name)} IS NULL")
        end
        execute("ALTER TABLE #{quote_table_name(table_name)} ALTER #{quote_column_name(column_name)} #{null ? 'DROP' : 'SET'} NOT NULL")
      end

1059 1060
      # Renames a column in a table.
      # Calls +clear_cache!+ and issues <tt>ALTER TABLE ... RENAME COLUMN ... TO ...</tt>
      # with quoted table and column names.
      def rename_column(table_name, column_name, new_column_name)
        clear_cache!
        execute "ALTER TABLE #{quote_table_name(table_name)} RENAME COLUMN #{quote_column_name(column_name)} TO #{quote_column_name(new_column_name)}"
      end
1064

1065 1066 1067 1068
      # Drops the index called +index_name+ (quoted with +quote_table_name+).
      # +table_name+ is not used by this implementation.
      def remove_index!(table_name, index_name) #:nodoc:
        quoted_index = quote_table_name(index_name)
        execute "DROP INDEX #{quoted_index}"
      end

1069 1070 1071 1072
      # Renames an index with <tt>ALTER INDEX ... RENAME TO ...</tt>.
      # +table_name+ is not used by this implementation.
      #
      # NOTE(review): the old name is quoted with +quote_column_name+ while the
      # new name is quoted with +quote_table_name+ -- confirm this asymmetry is
      # intended.
      def rename_index(table_name, old_name, new_name)
        quoted_old = quote_column_name(old_name)
        quoted_new = quote_table_name(new_name)
        execute "ALTER INDEX #{quoted_old} RENAME TO #{quoted_new}"
      end

1073 1074
      # Maximum index name length reported by this adapter.
      def index_name_length
        63
      end
1076

1077 1078
      # Maps logical Rails types to PostgreSQL-specific data types.
      # Only 'integer' is handled here; every other type is delegated to the
      # superclass. For integers, +limit+ is a byte size: 1-2 maps to smallint,
      # 3-4 to integer, 5-8 to bigint, and no limit to integer.
      #
      # Raises ActiveRecordError for any other integer limit.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        return super unless type.to_s == 'integer'
        return 'integer' unless limit

        case limit
        when 1, 2 then 'smallint'
        when 3, 4 then 'integer'
        when 5..8 then 'bigint'
        else raise(ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead.")
        end
      end
1089

1090
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
1091 1092 1093
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column.
1094
      #
1095
      #   distinct("posts.id", "posts.created_at desc")
1096 1097
      # +columns+ is the select list as a string and +orders+ a list of ORDER BY
      # fragments. Each non-blank order fragment, with its ASC/DESC modifier
      # removed, is appended to the select list as <tt>alias_N</tt>. When no
      # usable order column remains, only <tt>DISTINCT columns</tt> is returned.
      def distinct(columns, orders) #:nodoc:
        return "DISTINCT #{columns}" if orders.empty?

        # Construct a clean list of column names from the ORDER BY clause, removing
        # any ASC/DESC modifiers
        order_columns = orders.map { |s| s.gsub(/\s+(ASC|DESC)\s*/i, '') }
        order_columns = order_columns.reject { |c| c.blank? }

        # Avoid emitting a dangling ", " when every order fragment was blank.
        return "DISTINCT #{columns}" if order_columns.empty?

        order_columns = order_columns.each_with_index.map { |s, i| "#{s} AS alias_#{i}" }

        "DISTINCT #{columns}, #{order_columns.join(', ')}"
      end
1107

1108
      # Helpers that do not depend on adapter state; callable as module
      # functions through <tt>extend self</tt>.
      module Utils
        extend self

        # Returns an array of <tt>[schema_name, table_name]</tt> extracted from +name+.
        # +schema_name+ is nil if not specified in +name+.
        # +schema_name+ and +table_name+ exclude surrounding quotes (regardless of whether provided in +name+)
        # +name+ supports the range of schema/table references understood by PostgreSQL, for example:
        #
        # * <tt>table_name</tt>
        # * <tt>"table.name"</tt>
        # * <tt>schema_name.table_name</tt>
        # * <tt>schema_name."table.name"</tt>
        # * <tt>"schema.name"."table name"</tt>
        def extract_schema_and_table(name)
          # Only the first two identifiers are considered; reversing them puts
          # the table first so a lone identifier is read as the table name.
          table, schema = name.scan(/[^".\s]+|"[^"]*"/)[0..1].map { |m| m.gsub(/(^"|"$)/, '') }.reverse
          [schema, table]
        end
      end

1127
      protected
1128
        # Returns the version of the connected PostgreSQL server.
1129
        # Delegates to <tt>@connection.server_version</tt>.
        def postgresql_version
          @connection.server_version
        end

1133 1134 1135
        # Maps a driver exception to an Active Record error based on the text
        # of <tt>exception.message</tt>: unique-constraint violations become
        # RecordNotUnique, foreign-key violations become InvalidForeignKey, and
        # everything else is left to the superclass.
        def translate_exception(exception, message)
          case exception.message
          when /duplicate key value violates unique constraint/
            RecordNotUnique.new(message, exception)
          when /violates foreign key constraint/
            InvalidForeignKey.new(message, exception)
          else
            super
          end
        end

D
Initial  
David Heinemeier Hansson 已提交
1144
      private
1145 1146
        FEATURE_NOT_SUPPORTED = "0A000" # :nodoc:

1147 1148
        # Sends +sql+ through <tt>@connection.async_exec</tt> without preparing
        # a statement. +binds+ is accepted but not used.
        def exec_no_cache(sql, binds)
          @connection.async_exec(sql)
        end
1150

1151
        # Executes +sql+ as a prepared statement. Each <tt>[column, value]</tt>
        # pair in +binds+ is converted with +type_cast+ before being sent.
        #
        # When the server answers with SQLSTATE FEATURE_NOT_SUPPORTED, the
        # cached statement is dropped and the query is re-prepared and retried
        # once. Any other PGError, a second FEATURE_NOT_SUPPORTED in a row, or
        # an error that carries no result object is re-raised unchanged.
        def exec_cache(sql, binds)
          retried = false

          begin
            stmt_key = prepare_statement sql

            # Clear the queue
            @connection.get_last_result
            @connection.send_query_prepared(stmt_key, binds.map { |col, val|
              type_cast(val, col)
            })
            @connection.block
            @connection.get_last_result
          rescue PGError => e
            # Get the PG code for the failure.  Annoyingly, the code for
            # prepared statements whose return value may have changed is
            # FEATURE_NOT_SUPPORTED.  Check here for more details:
            # http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/backend/utils/cache/plancache.c#l573
            result = e.result
            code = result && result.result_error_field(PGresult::PG_DIAG_SQLSTATE)

            # Retry only once: a statement that still fails after being
            # re-prepared would otherwise loop forever.
            raise e unless FEATURE_NOT_SUPPORTED == code && !retried

            retried = true
            @statements.delete sql_key(sql)
            retry
          end
        end

        # Returns the statement identifier for the client side cache
        # of statements
        # The key is the current schema search path and +sql+ joined by "-",
        # so the same SQL under a different search path gets its own entry.
        def sql_key(sql)
          search_path = schema_search_path
          "#{search_path}-#{sql}"
        end

        # Prepare the statement if it hasn't been prepared, return
        # the statement key.
        # Looks +sql+ up in <tt>@statements</tt> under its +sql_key+. On a miss,
        # a new statement name is taken from <tt>@statements.next_key</tt>, the
        # statement is prepared on the connection, and the name is cached.
        # Returns the statement name.
        def prepare_statement(sql)
          cache_key = sql_key(sql)
          unless @statements.key? cache_key
            nextkey = @statements.next_key
            @connection.prepare nextkey, sql
            @statements[cache_key] = nextkey
          end
          @statements[cache_key]
        end
1194

P
Pratik Naik 已提交
1195
        # The internal PostgreSQL identifier of the money data type.
1196
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:
1197 1198
        # The internal PostgreSQL identifier of the BYTEA data type.
        BYTEA_COLUMN_TYPE_OID = 17 #:nodoc:
1199 1200 1201 1202 1203 1204 1205 1206 1207

        # Connects to a PostgreSQL server and sets up the adapter depending on the
        # connected server's characteristics.
        # Opens the connection with <tt>PGconn.connect</tt> using
        # <tt>@connection_parameters</tt>, records the money precision for the
        # server version on PostgreSQLColumn, and then calls +configure_connection+.
        def connect
          @connection = PGconn.connect(*@connection_parameters)

          # Money type has a fixed precision of 10 in PostgreSQL 8.2 and below, and as of
          # PostgreSQL 8.3 it has a fixed precision of 19. PostgreSQLColumn.extract_precision
          # should know about this but can't detect it there, so deal with it here.
          PostgreSQLColumn.money_precision = (postgresql_version >= 80300) ? 19 : 10

          configure_connection
        end

1213
        # Configures the encoding, verbosity, schema search path, and time zone of the connection.
1214
        # This is called by #connect and should not be called manually.
1215 1216
        # Applies, in order: the client encoding (when <tt>@config[:encoding]</tt>
        # is set), the minimum message level (when <tt>@config[:min_messages]</tt>
        # is set), the schema search path (from :schema_search_path or
        # :schema_order), standard-conforming strings, and the session time zone.
        def configure_connection
          if @config[:encoding]
            @connection.set_client_encoding(@config[:encoding])
          end
          self.client_min_messages = @config[:min_messages] if @config[:min_messages]
          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]

          # Use standard-conforming strings if available so we don't have to do the E'...' dance.
          set_standard_conforming_strings

          # If using Active Record's time zone support configure the connection to return
          # TIMESTAMP WITH TIME ZONE types in UTC.
          if ActiveRecord::Base.default_timezone == :utc
            execute("SET time zone 'UTC'", 'SCHEMA')
          elsif @local_tz
            execute("SET time zone '#{@local_tz}'", 'SCHEMA')
          end
        end

1234
        # Returns the current ID of a table's sequence.
1235 1236 1237
        # Runs <tt>SELECT currval($1)</tt> with +sequence_name+ as the bind
        # parameter and converts the first column of the first row with Integer().
        def last_insert_id(sequence_name) #:nodoc:
          r = exec_query("SELECT currval($1)", 'SQL', [[nil, sequence_name]])
          Integer(r.rows.first.first)
        end

1240
        # Executes a SELECT query and returns the results, performing any data type
1241
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
1242
        # Passes +sql+, +name+ and +binds+ to +exec_query+ and returns the
        # result converted with +to_a+.
        def select(sql, name = nil, binds = [])
          exec_query(sql, name, binds).to_a
        end

        # Executes +sql+ and returns <tt>[fields, rows]</tt>: the result's field
        # names and its rows as produced by +result_as_array+. The driver
        # result is cleared once both have been read.
        def select_raw(sql, name = nil)
          res = execute(sql, name)
          results = result_as_array(res)
          fields = res.fields
          res.clear
          [fields, results]
        end

1254
        # Returns the list of a table's column names, data types, and default values.
1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - ::regclass is a cast that resolves a table name to its id (OID)
1272
        # Returns the raw rows of the catalog query: for each live column of
        # +table_name+, in attribute-number order, its name, formatted type,
        # default source and not-null flag.
        def column_definitions(table_name) #:nodoc:
          exec_query(<<-end_sql, 'SCHEMA').rows
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), d.adsrc, a.attnotnull
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{quote_table_name(table_name)}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
        end
1282 1283

        # Splits the leading identifier off +name+ and returns
        # <tt>[identifier, rest]</tt>. A leading double-quoted identifier is
        # taken without its quotes; otherwise the identifier runs up to the
        # first dot. +rest+ is what follows (minus one leading dot), or nil when
        # nothing follows. Returns nil when no identifier can be matched.
        def extract_pg_identifier_from_name(name)
          match_data = name.start_with?('"') ? name.match(/\"([^\"]+)\"/) : name.match(/([^\.]+)/)

          if match_data
            rest = name[match_data[0].length, name.length]
            rest = rest[1, rest.length] if rest.start_with? "."
            [match_data[1], (rest.length > 0 ? rest : nil)]
          end
        end
1292

1293 1294 1295 1296 1297
        def extract_table_ref_from_insert_sql(sql)
          sql[/into\s+([^\(]*).*values\s*\(/i]
          $1.strip if $1
        end

1298 1299 1300
        # Returns a new TableDefinition constructed with this adapter (+self+).
        def table_definition
          TableDefinition.new(self)
        end
D
Initial  
David Heinemeier Hansson 已提交
1301 1302 1303
    end
  end
end