postgresql_adapter.rb 51.5 KB
Newer Older
D
Initial  
David Heinemeier Hansson 已提交
1
require 'active_record/connection_adapters/abstract_adapter'
2
require 'active_support/core_ext/object/blank'
3
require 'active_record/connection_adapters/statement_pool'
4
require 'active_record/connection_adapters/postgresql/oid'
5
require 'arel/visitors/bind_visitor'
6 7 8

# Make sure we're using pg high enough for PGResult#values
gem 'pg', '~> 0.11'
9
require 'pg'
D
Initial  
David Heinemeier Hansson 已提交
10

D
Dan McClain 已提交
11 12
require 'ipaddr'

D
Initial  
David Heinemeier Hansson 已提交
13
module ActiveRecord
14
  module ConnectionHandling
    # Establishes a connection to the database that's used by all Active Record objects.
    #
    # +config+ is the connection configuration hash. Options interpreted by
    # Active Record itself are removed, nil-valued options are dropped,
    # +:username+ / +:database+ are renamed to libpq's +:user+ / +:dbname+,
    # and the remaining entries are passed to the adapter as connection
    # parameters. The original +config+ is passed along unchanged.
    def postgresql_connection(config) # :nodoc:
      conn_params = config.symbolize_keys

      # Options that only Active Record understands; they must not be
      # forwarded to PGconn.connect.
      ar_only_options = [:statement_limit, :encoding, :min_messages, :schema_search_path,
                         :schema_order, :adapter, :pool, :checkout_timeout, :template,
                         :reaping_frequency, :insert_returning]
      ar_only_options.each { |option| conn_params.delete(option) }
      conn_params.reject! { |_option, value| value.nil? }

      # Map Active Record's parameter names to the ones libpq expects.
      { :username => :user, :database => :dbname }.each do |ar_name, pg_name|
        conn_params[pg_name] = conn_params.delete(ar_name) if conn_params[ar_name]
      end

      # The postgres drivers don't allow the creation of an unconnected PGconn object,
      # so just pass a nil connection object for the time being.
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, conn_params, config)
    end
  end
36

37 38 39 40
  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      #
      # +default+ is the raw default expression reported by PostgreSQL; it is
      # reduced to a literal by ::extract_value_from_default before being passed
      # on to Column. +oid_type+ is the type object #type_cast delegates to.
      def initialize(name, default, oid_type, sql_type = nil, null = true)
        @oid_type = oid_type

        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end

      # :stopdoc:
      class << self
        attr_accessor :money_precision

        # Returns Float infinities for 'infinity' / '-infinity', defers every
        # other String to the superclass, and returns non-Strings unchanged.
        def string_to_time(string)
          return string unless String === string

          case string
          when 'infinity'  then 1.0 / 0.0
          when '-infinity' then -1.0 / 0.0
          else
            super
          end
        end

        # Serializes a Hash into hstore text ("k"=>"v",...); any other object
        # is returned unchanged.
        def hstore_to_string(object)
          if Hash === object
            object.map { |k,v|
              "#{escape_hstore(k)}=>#{escape_hstore(v)}"
            }.join ','
          else
            object
          end
        end

        # Parses hstore text into a Hash. Surrounding double quotes and
        # backslash escapes are removed; a value equal to NULL (any case,
        # unquoted) becomes nil. nil and non-String input are returned as-is.
        def string_to_hstore(string)
          if string.nil?
            nil
          elsif String === string
            Hash[string.scan(HstorePair).map { |k,v|
              v = v.upcase == 'NULL' ? nil : v.gsub(/^"(.*)"$/,'\1').gsub(/\\(.)/, '\1')
              k = k.gsub(/^"(.*)"$/,'\1').gsub(/\\(.)/, '\1')
              [k,v]
            }]
          else
            string
          end
        end

        # Parses a cidr/inet String into an IPAddr; nil and non-String input
        # are returned unchanged.
        def string_to_cidr(string)
          if string.nil?
            nil
          elsif String === string
            IPAddr.new(string)
          else
            string
          end
        end

        # Formats an IPAddr as "address/prefix_length"; any other object is
        # returned unchanged.
        def cidr_to_string(object)
          if IPAddr === object
            # The prefix length is the number of set bits in the netmask, read
            # from IPAddr's internal @mask_addr (older IPAddr versions expose no
            # public reader for it).
            "#{object.to_s}/#{object.instance_variable_get(:@mask_addr).to_s(2).count('1')}"
          else
            object
          end
        end

        private
        # Matches one key=>value pair of hstore text; each side is either a
        # double-quoted string (with backslash escapes) or an unquoted token.
        HstorePair = begin
          quoted_string = /"[^"\\]*(?:\\.[^"\\]*)*"/
          unquoted_string = /(?:\\.|[^\s,])[^\s=,\\]*(?:\\.[^\s=,\\]*|=[^,>])*/
          /(#{quoted_string}|#{unquoted_string})\s*=>\s*(#{quoted_string}|#{unquoted_string})/
        end

        # Quotes one hstore key or value: nil becomes NULL, "" becomes "",
        # anything else is double-quoted with " and \ backslash-escaped.
        def escape_hstore(value)
          if value.nil?
            'NULL'
          elsif value == ""
            '""'
          else
            '"%s"' % value.to_s.gsub(/(["\\])/, '\\\\\1')
          end
        end
      end
      # :startdoc:

      # Extracts the value from a PostgreSQL column default definition.
      #
      # Returns the literal part of +default+ as a String (true/false for the
      # boolean defaults), or nil when the default is an expression whose value
      # cannot be known up front (a function call, a user type, ...).
      def self.extract_value_from_default(default)
        # This is a performance optimization for Ruby 1.9.2 in development.
        # If the value is nil, we return nil straight away without checking
        # the regular expressions. If we check each regular expression,
        # Regexp#=== will call NilClass#to_str, which will trigger
        # method_missing (defined by whiny nil in ActiveSupport) which
        # makes this method very very slow.
        return default unless default

        case default
          # Numeric types. Optional surrounding parentheses are not part of the
          # value, so the closing one stays outside the capture ("(-1)" => "-1").
          when /\A\(?(-?\d+(?:\.\d*)?)\)?\z/
            $1
          # Character types. Embedded single quotes are doubled inside the
          # quoted literal, so collapse them back to one.
          when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
            $1.gsub("''", "'")
          # Character types (8.1 formatting)
          when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
            $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
          # Binary data types
          when /\A'(.*)'::bytea\z/m
            $1
          # Date/time types
          when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
            $1
          when /\A'(.*)'::interval\z/
            $1
          # Boolean type
          when 'true'
            true
          when 'false'
            false
          # Geometric types
          when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
            $1
          # Network address types
          when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
            $1
          # Bit string types
          when /\AB'(.*)'::"?bit(?: varying)?"?\z/
            $1
          # XML type
          when /\A'(.*)'::xml\z/m
            $1
          # Arrays
          when /\A'(.*)'::"?\D+"?\[\]\z/
            $1
          # Hstore
          when /\A'(.*)'::hstore\z/
            $1
          # Object identifier types. This pattern has no capture group, so the
          # whole default is the value. (Plain integers are normally already
          # matched by the numeric branch above.)
          when /\A-?\d+\z/
            default
          else
            # Anything else is blank, some user type, or some function
            # and we can't know the value of that, so return nil.
            nil
        end
      end

      # Casts a raw database value through the column's OID type object.
      # nil stays nil; encoded columns use the superclass casting.
      def type_cast(value)
        return if value.nil?
        return super if encoded?

        @oid_type.type_cast value
      end

      private
      # bigint and smallint have fixed limits of 8 and 2.
      def extract_limit(sql_type)
        case sql_type
        when /^bigint/i;    8
        when /^smallint/i;  2
        else super
        end
      end

      # Extracts the scale from PostgreSQL-specific data types.
      def extract_scale(sql_type)
        # Money type has a fixed scale of 2.
        sql_type =~ /^money/ ? 2 : super
      end

      # Extracts the precision from PostgreSQL-specific data types.
      def extract_precision(sql_type)
        if sql_type == 'money'
          self.class.money_precision
        else
          super
        end
      end

      # Maps PostgreSQL-specific data types to logical Rails types.
      def simplified_type(field_type)
        case field_type
        # Numeric and monetary types
        when /^(?:real|double precision)$/
          :float
        # Monetary types
        when 'money'
          :decimal
        when 'hstore'
          :hstore
        # Network address types
        when 'inet'
          :inet
        when 'cidr'
          :cidr
        when 'macaddr'
          :macaddr
        # Character types
        when /^(?:character varying|bpchar)(?:\(\d+\))?$/
          :string
        # Binary data types
        when 'bytea'
          :binary
        # Date/time types
        when /^timestamp with(?:out)? time zone$/
          :datetime
        when 'interval'
          :string
        # Geometric types
        when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
          :string
        # Bit strings
        when /^bit(?: varying)?(?:\(\d+\))?$/
          :string
        # XML type
        when 'xml'
          :xml
        # tsvector type
        when 'tsvector'
          :tsvector
        # Arrays
        when /^\D+\[\]$/
          :string
        # Object identifier types
        when 'oid'
          :integer
        # UUID type
        when 'uuid'
          :string
        # Small and big integer types
        when /^(?:small|big)int$/
          :integer
        # Pass through all types that are not specific to PostgreSQL.
        else
          super
        end
      end
    end

270
    # The PostgreSQL adapter works with the native C (https://bitbucket.org/ged/ruby-pg) driver.
271 272 273
    #
    # Options:
    #
274 275
    # * <tt>:host</tt> - Defaults to a Unix-domain socket in /tmp. On machines without Unix-domain sockets,
    #   the default is to connect to localhost.
P
Pratik Naik 已提交
276
    # * <tt>:port</tt> - Defaults to 5432.
277 278 279
    # * <tt>:username</tt> - Defaults to be the same as the operating system name of the user running the application.
    # * <tt>:password</tt> - Password to be used if the server demands password authentication.
    # * <tt>:database</tt> - Defaults to be the same as the user name.
280
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection given
281
    #   as a string of comma-separated schema names. This is backward-compatible with the <tt>:schema_order</tt> option.
282
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO
283
    #   <encoding></tt> call on the connection.
284
    # * <tt>:min_messages</tt> - An optional client min messages that is used in a
285
    #   <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
286 287
    # * <tt>:insert_returning</tt> - An optional boolean to control the use of <tt>RETURNING</tt> for <tt>INSERT</tt> statements;
    #   defaults to true.
288 289 290 291 292 293 294
    #
    # Any further options are used as connection parameters to libpq. See
    # http://www.postgresql.org/docs/9.1/static/libpq-connect.html for the
    # list of parameters.
    #
    # In addition, default connection parameters of libpq can be set per environment variables.
    # See http://www.postgresql.org/docs/9.1/static/libpq-envars.html .
295
    class PostgreSQLAdapter < AbstractAdapter
296 297 298 299 300
      # Table definition with helpers for PostgreSQL-specific column types.
      class TableDefinition < ActiveRecord::ConnectionAdapters::TableDefinition
        # Adds an +xml+ column: the column name followed by an optional
        # trailing options hash.
        def xml(*args)
          column_options = args.extract_options!
          column(args.first, 'xml', column_options)
        end

        # Adds a +tsvector+ column: the column name followed by an optional
        # trailing options hash.
        def tsvector(*args)
          column_options = args.extract_options!
          column(args.first, 'tsvector', column_options)
        end

        # Adds an +hstore+ column.
        def hstore(name, options = {})
          column(name, 'hstore', options)
        end

        # Adds an +inet+ column.
        def inet(name, options = {})
          column(name, 'inet', options)
        end

        # Adds a +cidr+ column.
        def cidr(name, options = {})
          column(name, 'cidr', options)
        end

        # Adds a +macaddr+ column.
        def macaddr(name, options = {})
          column(name, 'macaddr', options)
        end
      end

324
      ADAPTER_NAME = 'PostgreSQL'

      # Maps Active Record's abstract column types to PostgreSQL type names
      # (plus a default limit where one applies).
      NATIVE_DATABASE_TYPES = {
        primary_key: "serial primary key",
        string:      { name: "character varying", limit: 255 },
        text:        { name: "text" },
        integer:     { name: "integer" },
        float:       { name: "float" },
        decimal:     { name: "decimal" },
        datetime:    { name: "timestamp" },
        timestamp:   { name: "timestamp" },
        time:        { name: "time" },
        date:        { name: "date" },
        binary:      { name: "bytea" },
        boolean:     { name: "boolean" },
        xml:         { name: "xml" },
        tsvector:    { name: "tsvector" },
        hstore:      { name: "hstore" },
        inet:        { name: "inet" },
        cidr:        { name: "cidr" },
        macaddr:     { name: "macaddr" }
      }

347
      # Identifies this adapter: the value of ADAPTER_NAME ('PostgreSQL').
      def adapter_name
        ADAPTER_NAME
      end

352 353
      # Prepared statements are cached per connection (see StatementPool),
      # so this always answers +true+.
      def supports_statement_cache?
        true
      end

358 359 360 361
      # Index columns may declare a sort order; always +true+.
      def supports_index_sort_order?
        true
      end

362 363 364 365
      # Partial (WHERE-restricted) indexes are available; always +true+.
      def supports_partial_index?
        true
      end

366 367 368 369
      # Cache of prepared statements keyed by SQL text. Values are the
      # server-side statement names handed out by #next_key ("a1", "a2", ...).
      # When the pool is full, the oldest entries are deallocated on the server.
      class StatementPool < ConnectionAdapters::StatementPool
        def initialize(connection, max)
          super
          @counter = 0
          # One inner cache per process id (see #cache).
          @cache   = Hash.new { |h,pid| h[pid] = {} }
        end

        def each(&block); cache.each(&block); end
        def key?(key);    cache.key?(key); end
        def [](key);      cache[key]; end
        def length;       cache.length; end

        # Name for the next prepared statement. The counter only advances
        # when an entry is stored through #[]=.
        def next_key
          "a#{@counter + 1}"
        end

        # Stores statement name +key+ for +sql+, first evicting and
        # deallocating the oldest entries while the cache holds @max or more
        # (@max is presumably set by the superclass from +max+).
        def []=(sql, key)
          while @max <= cache.size
            dealloc(cache.shift.last)
          end
          @counter += 1
          cache[sql] = key
        end

        # Deallocates every cached statement and empties the current
        # process's cache.
        def clear
          cache.each_value do |stmt_key|
            dealloc stmt_key
          end
          cache.clear
        end

        # Removes +sql_key+ and deallocates its statement. An uncached key is
        # ignored; deallocating it would send a bare "DEALLOCATE " to the
        # server, which is a syntax error.
        def delete(sql_key)
          return unless cache.key?(sql_key)
          dealloc cache[sql_key]
          cache.delete sql_key
        end

        private
        # Statements cached by the current process. Keyed by pid so that,
        # after a fork, each process sees only the statements it prepared.
        def cache
          @cache[Process.pid]
        end

        # Best effort: DEALLOCATE is only sent while the connection is usable.
        def dealloc(key)
          @connection.query "DEALLOCATE #{key}" if connection_active?
        end

        def connection_active?
          @connection.status == PGconn::CONNECTION_OK
        rescue PGError
          false
        end
      end

418 419 420 421
      # Arel visitor used when prepared statements are disabled: the
      # PostgreSQL visitor with Arel's BindVisitor mixed in.
      class BindSubstitution < Arel::Visitors::PostgreSQL # :nodoc:
        include(Arel::Visitors::BindVisitor)
      end

422 423
      # Initializes and connects a PostgreSQL adapter.
      #
      # +connection_parameters+ are presumably consumed by #connect (defined
      # elsewhere) to open the connection; +config+ is the adapter
      # configuration. Raises a RuntimeError when the server is older than 8.2.
      def initialize(connection, logger, connection_parameters, config)
        super(connection, logger)

        # With prepared statements disabled, use the BindSubstitution visitor.
        use_prepared_statements = config.fetch(:prepared_statements) { true }
        @visitor = use_prepared_statements ? Arel::Visitors::PostgreSQL.new(self) : BindSubstitution.new(self)

        # :prepared_statements is an adapter option, not a connection parameter.
        connection_parameters.delete :prepared_statements

        @connection_parameters = connection_parameters
        @config                = config

        # @local_tz is initialized as nil to avoid warnings when connect tries to use it
        @local_tz           = nil
        @table_alias_length = nil

        connect
        @statements = StatementPool.new(@connection, config.fetch(:statement_limit) { 1000 })

        if postgresql_version < 80200
          raise "Your version of PostgreSQL (#{postgresql_version}) is too old, please upgrade!"
        end

        initialize_type_map
        @local_tz = execute('SHOW TIME ZONE', 'SCHEMA').first["TimeZone"]
        @use_insert_returning = @config.fetch(:insert_returning) { true }
      end

X
Xavier Noria 已提交
453
      # Empties the prepared statement pool.
      def clear_cache!
        @statements.clear
      end

458 459
      # Is this connection alive and ready for queries? Runs a trivial query
      # and reports whether it succeeded without a PGError.
      def active?
        begin
          @connection.query 'SELECT 1'
        rescue PGError
          return false
        end
        true
      end

      # Close then reopen the connection: drop cached statements, reset the
      # underlying connection, then reapply the session configuration.
      def reconnect!
        clear_cache!
        @connection.reset
        configure_connection
      end
472

473 474 475 476 477
      # Drops cached prepared statements before the superclass reset.
      def reset!
        clear_cache!
        super
      end

478 479
      # Disconnects from the database if already connected. Otherwise, this
      # method does nothing. Cached statements are dropped first; any
      # StandardError raised while closing is ignored.
      def disconnect!
        clear_cache!
        begin
          @connection.close
        rescue StandardError
          # Best effort: an already-broken connection may refuse to close.
          nil
        end
      end
484

485
      # The abstract-to-PostgreSQL type mapping (NATIVE_DATABASE_TYPES).
      def native_database_types #:nodoc:
        NATIVE_DATABASE_TYPES
      end
488

489
      # Migrations are supported; always +true+.
      def supports_migrations?
        true
      end

494
      # Primary keys can be discovered on tables not created by Active
      # Record; always +true+.
      def supports_primary_key? #:nodoc:
        true
      end

499 500 501
      # Enable standard-conforming strings if available.
      #
      # client_min_messages is temporarily raised to 'panic' around the SET
      # (presumably so an unsupported setting doesn't surface a server
      # message). Any StandardError from the SET is ignored, and the previous
      # client_min_messages is always restored.
      def set_standard_conforming_strings
        previous_level = client_min_messages
        self.client_min_messages = 'panic'
        begin
          execute('SET standard_conforming_strings = on', 'SCHEMA')
        rescue StandardError
          nil
        end
      ensure
        self.client_min_messages = previous_level
      end

507
      # INSERT ... RETURNING is available; always +true+.
      def supports_insert_with_returning?
        true
      end

511 512 513
      # DDL statements can run inside transactions; always +true+.
      def supports_ddl_transactions?
        true
      end
514

515
      # Savepoints are available; always +true+.
      def supports_savepoints?
        true
      end
519

520 521 522 523 524
      # EXPLAIN is available (see #explain); always +true+.
      def supports_explain?
        true
      end

525
      # The server's max_identifier_length, queried once and memoized.
      def table_alias_length
        return @table_alias_length if @table_alias_length

        @table_alias_length = query('SHOW max_identifier_length', 'SCHEMA')[0][0].to_i
      end
529

530 531
      # QUOTING ==================================================

532
      # Escapes binary strings for bytea input to the database; nil/false
      # yield nil.
      def escape_bytea(value)
        return unless value
        PGconn.escape_bytea(value)
      end

      # Unescapes bytea output from a database to the binary string it represents.
      # NOTE: This is NOT an inverse of escape_bytea! This is only to be used
      #       on escaped binary output from database drive. nil/false yield nil.
      def unescape_bytea(value)
        return unless value
        PGconn.unescape_bytea(value)
      end

544 545
      # Quotes PostgreSQL-specific data types for SQL input.
      #
      # Handles hstore hashes, inet/cidr addresses, non-finite floats, money
      # values, and bytea/xml/bit columns. Everything else, and any value
      # quoted without a +column+, is delegated to the superclass.
      def quote(value, column = nil) #:nodoc:
        return super unless column

        case value
        when Hash
          case column.sql_type
          when 'hstore' then super(PostgreSQLColumn.hstore_to_string(value), column)
          else super
          end
        when IPAddr
          case column.sql_type
          when 'inet', 'cidr' then super(PostgreSQLColumn.cidr_to_string(value), column)
          else super
          end
        when Float
          if value.infinite? && column.type == :datetime
            "'#{value.to_s.downcase}'"
          elsif value.infinite? || value.nan?
            "'#{value.to_s}'"
          else
            super
          end
        when Numeric
          return super unless column.sql_type == 'money'
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value}'"
        when String
          case column.sql_type
          when 'bytea' then "'#{escape_bytea(value)}'"
          when 'xml'   then "xml '#{quote_string(value)}'"
          when /^bit/
            # \A and \z anchor the WHOLE value: with ^/$ a multi-line value
            # such as "01\n'); DROP TABLE ..." matched and was interpolated
            # unescaped. Anything that is not a pure bit/hex string falls back
            # to ordinary (escaped) quoting instead of returning nil.
            case value
            when /\A[01]*\z/      then "B'#{value}'" # Bit-string notation
            when /\A[0-9A-F]*\z/i then "X'#{value}'" # Hexadecimal notation
            else super
            end
          else
            super
          end
        else
          super
        end
      end

588 589 590 591 592 593
      # Converts +value+ into the form sent to PostgreSQL for +column+:
      # bytea strings become binary-format parameter hashes, hstore hashes and
      # inet/cidr addresses are serialized to text, and everything else (or
      # any call without a column) is delegated to the superclass.
      def type_cast(value, column)
        return super unless column

        case value
        when String
          return super unless 'bytea' == column.sql_type
          # :format => 1 requests binary transfer of this parameter from the pg driver.
          { :value => value, :format => 1 }
        when Hash
          return super unless 'hstore' == column.sql_type
          PostgreSQLColumn.hstore_to_string(value)
        when IPAddr
          # Array#include? (the previous Array#includes? does not exist and
          # raised NoMethodError for every IPAddr value).
          return super unless ['inet', 'cidr'].include?(column.sql_type)
          PostgreSQLColumn.cidr_to_string(value)
        else
          super
        end
      end

606 607 608
      # Escapes +s+ for use inside a quoted SQL string literal, via the
      # connection's escape.
      def quote_string(s) #:nodoc:
        @connection.escape(s)
      end

611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629
      # Quotes a possibly schema-qualified table name. Handles:
      #
      # - table_name
      # - "table.name"
      # - schema_name.table_name
      # - schema_name."table.name"
      # - "schema.name".table_name
      # - "schema.name"."table.name"
      def quote_table_name(name)
        schema, name_part = extract_pg_identifier_from_name(name.to_s)
        return quote_column_name(schema) unless name_part

        table_name, _rest = extract_pg_identifier_from_name(name_part)
        "#{quote_column_name(schema)}.#{quote_column_name(table_name)}"
      end

630 631
      # Quotes +name+ (converted with to_s) as a PostgreSQL identifier.
      def quote_column_name(name) #:nodoc:
        identifier = name.to_s
        PGconn.quote_ident(identifier)
      end

635 636 637
      # Quote date/time values for use in SQL input. Time-like values that
      # respond to usec get a zero-padded six-digit ".usec" suffix.
      def quoted_date(value) #:nodoc:
        with_usec = value.acts_like?(:time) && value.respond_to?(:usec)
        base = super
        with_usec ? "#{base}.#{format('%06d', value.usec)}" : base
      end

645 646
      # Set the authorized user for this session, clearing cached statements
      # first. NOTE(review): +user+ is interpolated verbatim (which lets
      # keywords such as DEFAULT through), so it must never come from
      # untrusted input.
      def session_auth=(user)
        clear_cache!
        exec_query("SET SESSION AUTHORIZATION #{user}")
      end

651 652
      # REFERENTIAL INTEGRITY ====================================

653
      # Triggers can be toggled per table; always +true+.
      def supports_disable_referential_integrity? #:nodoc:
        true
      end

657
      # Runs the block with all triggers on every table disabled, re-enabling
      # them afterwards even if the block raises. Returns the block's value.
      def disable_referential_integrity #:nodoc:
        toggle_triggers_sql = lambda do |action|
          tables.map { |name| "ALTER TABLE #{quote_table_name(name)} #{action} TRIGGER ALL" }.join(";")
        end

        execute(toggle_triggers_sql.call('DISABLE')) if supports_disable_referential_integrity?
        yield
      ensure
        execute(toggle_triggers_sql.call('ENABLE')) if supports_disable_referential_integrity?
      end
667 668 669

      # DATABASE STATEMENTS ======================================

670
      # Runs EXPLAIN on the SQL for +arel+ and returns the plan formatted by
      # ExplainPrettyPrinter.
      def explain(arel, binds = [])
        plan = exec_query("EXPLAIN #{to_sql(arel, binds)}", 'EXPLAIN', binds)
        ExplainPrettyPrinter.new.pp(plan)
      end

      class ExplainPrettyPrinter # :nodoc:
        # Pretty prints the result of a EXPLAIN in a way that resembles the output of the
        # PostgreSQL shell:
        #
        #                                     QUERY PLAN
        #   ------------------------------------------------------------------------------
        #    Nested Loop Left Join  (cost=0.00..37.24 rows=8 width=0)
        #      Join Filter: (posts.user_id = users.id)
        #      ->  Index Scan using users_pkey on users  (cost=0.00..8.27 rows=1 width=4)
        #            Index Cond: (id = 1)
        #      ->  Seq Scan on posts  (cost=0.00..28.88 rows=8 width=4)
        #            Filter: (posts.user_id = 1)
        #   (6 rows)
        #
        def pp(result)
          header     = result.columns.first
          plan_lines = result.rows.map(&:first)

          # We add 2 because there's one char of padding at both sides, note
          # the extra hyphens in the example above.
          width = [header, *plan_lines].map(&:length).max + 2

          row_count = result.rows.length
          footer    = "(#{row_count} #{row_count == 1 ? 'row' : 'rows'})"

          output = [header.center(width).rstrip, '-' * width]
          output.concat(plan_lines.map { |line| " #{line}" })
          output << footer

          "#{output.join("\n")}\n"
        end
      end

712 713 714 715 716 717
      # Executes a SELECT query and returns an array of rows, each an array of
      # field values (the second element of select_raw's result).
      def select_rows(sql, name = nil)
        _columns, rows = select_raw(sql, name)
        rows
      end

718
      # Executes an INSERT query and returns the new record's ID.
      #
      # When +pk+ is not given it is looked up from the table named in +sql+.
      # With RETURNING enabled the id comes back from the INSERT itself;
      # otherwise the statement runs through the superclass and the id is read
      # from +sequence_name+, or from the default sequence of the table and
      # +pk+. Without a primary key the superclass result is returned.
      def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        unless pk
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          pk = primary_key(table_ref) if table_ref
        end

        if pk && use_insert_returning?
          select_value("#{sql} RETURNING #{quote_column_name(pk)}")
        elsif pk
          super
          # table_ref is only set above when pk had to be discovered; a caller
          # that passes pk explicitly still needs the table to derive the
          # default sequence name. (sequence_name is reassigned only after
          # the zsuper call so the superclass sees the caller's value.)
          sequence_name ||= default_sequence_name(table_ref || extract_table_ref_from_insert_sql(sql), pk)
          last_insert_id_value(sequence_name)
        else
          super
        end
      end
735
      # Makes +create+ an alias of +insert+ (not defined in this part of the
      # file; presumably inherited from the abstract adapter).
      alias :create :insert
736

737 738
      # create a 2D array representing the result set
      #
      # Returns res.values. bytea columns are passed through unescape_bytea,
      # and money columns have currency symbols and grouping separators
      # stripped, with the decimal separator normalized to '.'.
      def result_as_array(res) #:nodoc:
        # check if we have any binary column and if they need escaping
        ftypes = Array.new(res.nfields) { |i| [i, res.ftype(i)] }

        rows = res.values
        return rows unless ftypes.any? { |_, x|
          x == BYTEA_COLUMN_TYPE_OID || x == MONEY_COLUMN_TYPE_OID
        }

        typehash = ftypes.group_by { |_, type| type }
        binaries = typehash[BYTEA_COLUMN_TYPE_OID] || []
        monies   = typehash[MONEY_COLUMN_TYPE_OID] || []

        rows.each do |row|
          # unescape string passed BYTEA field (OID == 17)
          binaries.each do |index, _|
            row[index] = unescape_bytea(row[index])
          end

          # If this is a money type column and there are any currency symbols,
          # then strip them off. Indeed it would be prettier to do this in
          # PostgreSQLColumn.string_to_decimal but would break form input
          # fields that call value_before_type_cast.
          monies.each do |index, _|
            data = row[index]
            # Because money output is formatted according to the locale, there are two
            # cases to consider (note the decimal separators):
            #  (1) $12,345,678.12
            #  (2) $12.345.678,12
            # Non-bang gsub/sub are used: gsub! returns nil when it removes
            # nothing (e.g. "-12,34"), which made a chained sub! raise.
            case data
            when /\A-?\D+[\d,]+\.\d{2}\z/  # (1)
              row[index] = data.gsub(/[^-\d.]/, '')
            when /\A-?\D+[\d.]+,\d{2}\z/  # (2)
              row[index] = data.gsub(/[^-\d,]/, '').sub(',', '.')
            end
          end
        end
      end


      # Queries the database (logged under +name+) and returns the rows as
      # built by result_as_array.
      def query(sql, name = nil) #:nodoc:
        log(sql, name) { result_as_array(@connection.async_exec(sql)) }
      end

787
      # Executes an SQL statement, returning a PGresult object on success
      # or raising a PGError exception otherwise. The call is logged under +name+.
      def execute(sql, name = nil)
        log(sql, name) { @connection.async_exec(sql) }
      end

795
      # Bind placeholder for the zero-based +index+: PostgreSQL numbers
      # parameters from $1.
      def substitute_at(column, index)
        placeholder = "$#{index + 1}"
        Arel::Nodes::BindParam.new placeholder
      end

799 800 801 802 803 804 805
      # ActiveRecord::Result that also carries the per-column type objects
      # built by #exec_query.
      class Result < ActiveRecord::Result
        def initialize(columns, rows, column_types)
          # super must run first: the assignment below has to win over
          # whatever the parent initializer may set.
          super(columns, rows)
          @column_types = column_types
        end
      end

A
Aaron Patterson 已提交
806
      # Executes +sql+ (through the statement cache when +binds+ are given)
      # and returns a Result whose column types come from OID::TYPE_MAP;
      # columns with an unknown OID are reported via +warn+ and mapped to
      # OID::Identity. The PGresult is cleared even if building the Result
      # raises. As in #exec_delete, the value comes back from the log block.
      def exec_query(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = nil
          begin
            result = binds.empty? ? exec_no_cache(sql, binds) :
                                    exec_cache(sql, binds)

            types = {}
            result.fields.each_with_index do |fname, i|
              ftype = result.ftype i
              fmod  = result.fmod i
              types[fname] = OID::TYPE_MAP.fetch(ftype, fmod) { |oid, mod|
                warn "unknown OID: #{fname}(#{oid}) (#{sql})"
                OID::Identity.new
              }
            end

            Result.new(result.fields, result.values, types)
          ensure
            result.clear if result
          end
        end
      end

827 828 829 830 831 832 833 834 835
      # Executes a DELETE (also used for UPDATE via the alias below) and
      # returns the number of affected rows. The PGresult is cleared even if
      # reading the row count raises; the count is the log block's value.
      def exec_delete(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = nil
          begin
            result = binds.empty? ? exec_no_cache(sql, binds) :
                                    exec_cache(sql, binds)
            result.cmd_tuples
          ensure
            result.clear if result
          end
        end
      end
      alias :exec_update :exec_delete
837

838 839
      # Returns [sql, binds] for an INSERT, appending RETURNING <pk> when a
      # primary key is known (or can be discovered from the target table)
      # and RETURNING is enabled.
      def sql_for_insert(sql, pk, id_value, sequence_name, binds)
        pk ||= begin
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          primary_key(table_ref) if table_ref
        end

        sql = "#{sql} RETURNING #{quote_column_name(pk)}" if pk && use_insert_returning?

        [sql, binds]
      end

852 853
      # Executes an INSERT. With RETURNING enabled, or without a primary key,
      # the query result is returned directly. Otherwise the last inserted id
      # is fetched from +sequence_name+ (or the table's default sequence);
      # when no sequence can be determined the query result is returned.
      def exec_insert(sql, name, binds, pk = nil, sequence_name = nil)
        val = exec_query(sql, name, binds)
        return val if use_insert_returning? || !pk

        unless sequence_name
          table_ref = extract_table_ref_from_insert_sql(sql)
          sequence_name = default_sequence_name(table_ref, pk)
          return val unless sequence_name
        end

        last_insert_id_result(sequence_name)
      end

866
      # Executes an UPDATE query and returns the number of affected tuples
      # (cmd_tuples of the superclass result).
      def update_sql(sql, name = nil)
        result = super
        result.cmd_tuples
      end

871 872
      # Opens a transaction with BEGIN.
      def begin_db_transaction
        execute("BEGIN")
      end

876 877
      # Commits the current transaction with COMMIT.
      def commit_db_transaction
        execute("COMMIT")
      end
880

881 882
      # Aborts the current transaction with ROLLBACK.
      def rollback_db_transaction
        execute("ROLLBACK")
      end
885

886 887
      # True when the connection reports the idle (no transaction) status.
      def outside_transaction?
        status = @connection.transaction_status
        status == PGconn::PQTRANS_IDLE
      end
889

J
Jonathan Viney 已提交
890 891 892 893 894 895 896 897
      # Opens a savepoint named by current_savepoint_name.
      def create_savepoint
        savepoint = current_savepoint_name
        execute("SAVEPOINT #{savepoint}")
      end

      # Rolls back to the savepoint named by current_savepoint_name.
      def rollback_to_savepoint
        savepoint = current_savepoint_name
        execute("ROLLBACK TO SAVEPOINT #{savepoint}")
      end

898
      # Releases the savepoint named by current_savepoint_name.
      def release_savepoint
        savepoint = current_savepoint_name
        execute("RELEASE SAVEPOINT #{savepoint}")
      end
901

902 903
      # SCHEMA STATEMENTS ========================================

904 905 906
      # Drops the database specified on the +name+ attribute
      # and creates it again using the provided +options+.
      def recreate_database(name, options = {}) #:nodoc:
        # Drop whatever currently exists under +name+, then create it afresh
        # with the given +options+.
        drop_database(name)
        create_database(name, options)
      end

911
      # Create a new PostgreSQL database. Options include <tt>:owner</tt>, <tt>:template</tt>,
912 913
      # <tt>:encoding</tt>, <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
      # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
914 915 916 917 918 919 920 921 922 923
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        # Symbolize first so string keys (as loaded from database.yml) and the
        # default :encoding land on the same key.
        options = options.symbolize_keys.reverse_merge(:encoding => "utf8")

        # Identifiers (owner, template, tablespace) are escaped with
        # quote_column_name and the encoding with quote. A value containing a
        # quote character therefore cannot terminate the literal early and
        # inject SQL.
        option_string = options.map do |key, value|
          case key
          when :owner
            " OWNER = #{quote_column_name(value)}"
          when :template
            " TEMPLATE = #{quote_column_name(value)}"
          when :encoding
            " ENCODING = #{quote(value.to_s)}"
          when :tablespace
            " TABLESPACE = #{quote_column_name(value)}"
          when :connection_limit
            " CONNECTION LIMIT = #{value}"
          else
            ""
          end
        end.join

        execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
      end

941
      # Drops a PostgreSQL database.
942 943 944 945
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        # IF EXISTS lets this run against a database that is already gone.
        quoted_name = quote_table_name(name)
        execute "DROP DATABASE IF EXISTS #{quoted_name}"
      end

949 950
      # Returns the list of all tables in the schema search path or a specified schema.
      def tables(name = nil)
        # +name+ is accepted for interface compatibility and not used.
        rows = query(<<-SQL, 'SCHEMA')
          SELECT tablename
          FROM pg_tables
          WHERE schemaname = ANY (current_schemas(false))
        SQL
        rows.map(&:first)
      end

958
      # Returns true if table exists.
959 960
      # If the schema is not specified as part of +name+ then it will only find tables within
      # the current schema search path (regardless of permissions to access tables in other schemas)
961
      def table_exists?(name)
        schema, table = Utils.extract_schema_and_table(name.to_s)
        return false unless table

        # Utils has already removed surrounding double quotes from both parts.
        # The values go through quote so that a name containing a single quote
        # yields a well-formed literal instead of broken or injected SQL.
        schema_condition = schema ? quote(schema) : 'ANY (current_schemas(false))'

        exec_query(<<-SQL, 'SCHEMA').rows.first[0].to_i > 0
            SELECT COUNT(*)
            FROM pg_class c
            LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind in ('v','r')
            AND c.relname = #{quote(table)}
            AND n.nspname = #{schema_condition}
        SQL
      end

978 979
      # Returns true if schema exists.
      def schema_exists?(name)
        # quote escapes embedded single quotes, so +name+ is always a single
        # well-formed literal.
        exec_query(<<-SQL, 'SCHEMA').rows.first[0].to_i > 0
          SELECT COUNT(*)
          FROM pg_namespace
          WHERE nspname = #{quote(name.to_s)}
        SQL
      end
986

987
      # Returns an array of indexes for the given table.
988
      def indexes(table_name, name = nil)
        # The table name goes through quote so it is always a well-formed SQL
        # literal, even when it contains a single quote.
        result = query(<<-SQL, 'SCHEMA')
          SELECT distinct i.relname, d.indisunique, d.indkey, pg_get_indexdef(d.indexrelid), t.oid
          FROM pg_class t
          INNER JOIN pg_index d ON t.oid = d.indrelid
          INNER JOIN pg_class i ON d.indexrelid = i.oid
          WHERE i.relkind = 'i'
            AND d.indisprimary = 'f'
            AND t.relname = #{quote(table_name.to_s)}
            AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname = ANY (current_schemas(false)) )
          ORDER BY i.relname
        SQL

        result.map do |row|
          index_name = row[0]
          unique = row[1] == 't'
          # indkey is a space-separated list of attribute numbers.
          indkey = row[2].split(" ")
          inddef = row[3]
          oid = row[4]

          columns = Hash[query(<<-SQL, "Columns for index #{row[0]} on #{table_name}")]
          SELECT a.attnum, a.attname
          FROM pg_attribute a
          WHERE a.attrelid = #{oid}
          AND a.attnum IN (#{indkey.join(",")})
          SQL

          # compact drops any indkey entry that has no matching attribute row.
          column_names = columns.values_at(*indkey).compact

          # add info on sort order for columns (only desc order is explicitly specified, asc is the default)
          desc_order_columns = inddef.scan(/(\w+) DESC/).flatten
          orders = Hash[desc_order_columns.map { |order_column| [order_column, :desc] }]
          where = inddef.scan(/WHERE (.+)$/).flatten[0]

          column_names.empty? ? nil : IndexDefinition.new(table_name, index_name, unique, column_names, [], orders, where)
        end.compact
      end

1026
      # Returns the list of all column definitions for a table.
1027
      def columns(table_name)
        # Limit, precision, and scale are all handled by the superclass.
        column_definitions(table_name).map do |column_name, type, default, notnull, oid, fmod|
          # Unknown type OIDs fall back to an Identity handler.
          type_oid = OID::TYPE_MAP.fetch(oid.to_i, fmod.to_i) { OID::Identity.new }
          nullable = notnull == 'f'
          PostgreSQLColumn.new(column_name, default, type_oid, type, nullable)
        end
      end

1037 1038
      # Returns the current database name.
      def current_database
        rows = query('select current_database()', 'SCHEMA')
        rows.first.first
      end

1042 1043 1044 1045 1046
      # Returns the current schema name.
      def current_schema
        rows = query('SELECT current_schema', 'SCHEMA')
        rows.first.first
      end

1047 1048
      # Returns the current database encoding format.
      def encoding
        # Exact match rather than LIKE: '_' and '%' are LIKE wildcards, and both
        # can occur in database names (e.g. "app_development"). With LIKE, the
        # lookup could match a different database and report its encoding.
        query(<<-end_sql, 'SCHEMA')[0][0]
          SELECT pg_encoding_to_char(pg_database.encoding) FROM pg_database
          WHERE pg_database.datname = #{quote(current_database)}
        end_sql
      end

1055 1056
      # Returns an array of schema names.
      def schema_names
        # System schemas (pg_* and information_schema) are filtered out; each
        # returned row holds one name, so the rows are flattened into a list.
        name_rows = query(<<-SQL, 'SCHEMA')
          SELECT nspname
            FROM pg_namespace
           WHERE nspname !~ '^pg_.*'
             AND nspname NOT IN ('information_schema')
           ORDER by nspname;
        SQL
        name_rows.flatten
      end

1066 1067 1068 1069 1070 1071 1072 1073 1074 1075
      # Creates a schema for the given schema name.
      def create_schema(schema_name)
        # NOTE(review): schema_name is interpolated without quoting or escaping,
        # so it is case-folded by the server and must be a trusted, plain name.
        execute "CREATE SCHEMA #{schema_name}"
      end

      # Drops the schema for the given schema name.
      def drop_schema(schema_name)
        # CASCADE also drops every object contained in the schema.
        # NOTE(review): schema_name is interpolated without quoting or escaping.
        execute "DROP SCHEMA #{schema_name} CASCADE"
      end

1076 1077 1078 1079 1080 1081
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      #
      # This should be not be called manually but set in database.yml.
      def schema_search_path=(schema_csv)
        # A nil/false value leaves both the server setting and the cache untouched.
        return unless schema_csv

        execute("SET search_path TO #{schema_csv}", 'SCHEMA')
        @schema_search_path = schema_csv
      end

1088 1089
      # Returns the active schema search path.
      def schema_search_path
        # Memoized: the server is only asked while nothing has been cached yet.
        @schema_search_path ||= query('SHOW search_path', 'SCHEMA').first.first
      end
1092

1093 1094
      # Returns the current client message level.
      def client_min_messages
        rows = query('SHOW client_min_messages', 'SCHEMA')
        rows.first.first
      end

      # Set the client message level.
      def client_min_messages=(level)
        # +level+ is wrapped in single quotes but not escaped.
        statement = "SET client_min_messages TO '#{level}'"
        execute(statement, 'SCHEMA')
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        column = pk || 'id'
        qualified = serial_sequence(table_name, column)
        return nil unless qualified

        # Drop any schema prefix from the server's answer.
        qualified.split('.').last
      rescue ActiveRecord::StatementInvalid
        # Fall back to the conventional "<table>_<column>_seq" name.
        "#{table_name}_#{pk || 'id'}_seq"
      end

      # Asks the server (pg_get_serial_sequence) which sequence backs
      # +table+.+column+ and returns that name, or nil if there is none.
      # Both arguments are passed through quote so that embedded single quotes
      # cannot break out of the literals.
      def serial_sequence(table, column)
        result = exec_query(<<-eosql, 'SCHEMA')
          SELECT pg_get_serial_sequence(#{quote(table.to_s)}, #{quote(column.to_s)})
        eosql
        result.rows.first.first
      end

1119 1120
      # Resets the sequence of a table's primary key to the maximum value.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        # Fill in whichever of pk / sequence the caller left out from the
        # table's own metadata.
        if !pk || !sequence
          default_pk, default_sequence = pk_and_sequence_for(table)
          pk ||= default_pk
          sequence ||= default_sequence
        end

        if @logger && pk && !sequence
          @logger.warn "#{table} has primary key #{pk} with no default sequence"
        end
        return unless pk && sequence

        quoted_sequence = quote_table_name(sequence)

        # setval target: MAX(pk) + increment_by, or the sequence's min_value
        # when the table is empty (COALESCE), with is_called = false.
        select_value <<-end_sql, 'Reset sequence'
            SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
          end_sql
      end

1141 1142
      # Returns a table's primary key and belonging sequence.
      def pk_and_sequence_for(table) #:nodoc:
        # Strategy 1: a sequence that pg_depend records as owned by the table's
        # primary-key column.
        row = query(<<-end_sql, 'PK and serial sequence')[0]
          SELECT attr.attname, seq.relname
          FROM pg_class      seq,
               pg_attribute  attr,
               pg_depend     dep,
               pg_namespace  name,
               pg_constraint cons
          WHERE seq.oid           = dep.objid
            AND seq.relkind       = 'S'
            AND attr.attrelid     = dep.refobjid
            AND attr.attnum       = dep.refobjsubid
            AND attr.attrelid     = cons.conrelid
            AND attr.attnum       = cons.conkey[1]
            AND cons.contype      = 'p'
            AND dep.refobjid      = '#{quote_table_name(table)}'::regclass
        end_sql

        if row.nil? || row.empty?
          # Strategy 2: parse the sequence name out of the primary key's
          # nextval(...) default. Handles both nextval('foo'::text) and
          # nextval('foo'::regclass); any schema prefix is stripped.
          row = query(<<-end_sql, 'PK and custom sequence')[0]
            SELECT attr.attname,
              CASE
                WHEN split_part(def.adsrc, '''', 2) ~ '.' THEN
                  substr(split_part(def.adsrc, '''', 2),
                         strpos(split_part(def.adsrc, '''', 2), '.')+1)
                ELSE split_part(def.adsrc, '''', 2)
              END
            FROM pg_class       t
            JOIN pg_attribute   attr ON (t.oid = attrelid)
            JOIN pg_attrdef     def  ON (adrelid = attrelid AND adnum = attnum)
            JOIN pg_constraint  cons ON (conrelid = adrelid AND adnum = conkey[1])
            WHERE t.oid = '#{quote_table_name(table)}'::regclass
              AND cons.contype = 'p'
              AND def.adsrc ~* 'nextval'
          end_sql
        end

        pk_name, sequence_name = row.first, row.last
        [pk_name, sequence_name]
      rescue StandardError
        # Best effort: any failure (including no row at all) yields nil.
        nil
      end

1189 1190
      # Returns just a table's primary key
      def primary_key(table)
        # The ::regclass cast goes through quote_table_name, like
        # pk_and_sequence_for and column_definitions. A bare name would be
        # case-folded by the server, so mixed-case tables (and names that need
        # quoting) would not be found.
        row = exec_query(<<-end_sql, 'SCHEMA').rows.first
          SELECT DISTINCT(attr.attname)
          FROM pg_attribute attr
          INNER JOIN pg_depend dep ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          WHERE cons.contype = 'p'
            AND dep.refobjid = '#{quote_table_name(table)}'::regclass
        end_sql

        row && row.first
      end

1203
      # Renames a table.
1204 1205 1206
      #
      # Example:
      #   rename_table('octopuses', 'octopi')
1207
      def rename_table(name, new_name)
        # Cached metadata refers to the old name, so drop it first.
        clear_cache!
        from = quote_table_name(name)
        to = quote_table_name(new_name)
        execute "ALTER TABLE #{from} RENAME TO #{to}"
      end
1211

1212 1213
      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
S
Scott Barron 已提交
1214
      def add_column(table_name, column_name, type, options = {})
        clear_cache!
        quoted_table  = quote_table_name(table_name)
        quoted_column = quote_column_name(column_name)
        sql_type      = type_to_sql(type, options[:limit], options[:precision], options[:scale])

        add_column_sql = "ALTER TABLE #{quoted_table} ADD COLUMN #{quoted_column} #{sql_type}"
        # The return value of add_column_options! is ignored, so it is expected
        # to extend add_column_sql in place with the option clauses.
        add_column_options!(add_column_sql, options)

        execute add_column_sql
      end
D
Initial  
David Heinemeier Hansson 已提交
1221

1222 1223
      # Changes the column of a table.
      def change_column(table_name, column_name, type, options = {})
        clear_cache!
        quoted_table_name = quote_table_name(table_name)
        quoted_column     = quote_column_name(column_name)
        sql_type          = type_to_sql(type, options[:limit], options[:precision], options[:scale])

        execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quoted_column} TYPE #{sql_type}"

        # Default and nullability are separate ALTERs, applied only when requested.
        if options_include_default?(options)
          change_column_default(table_name, column_name, options[:default])
        end
        if options.key?(:null)
          change_column_null(table_name, column_name, options[:null], options[:default])
        end
      end
1232

1233 1234
      # Changes the default value of a table column.
      def change_column_default(table_name, column_name, default)
        clear_cache!
        quoted_table  = quote_table_name(table_name)
        quoted_column = quote_column_name(column_name)
        execute "ALTER TABLE #{quoted_table} ALTER COLUMN #{quoted_column} SET DEFAULT #{quote(default)}"
      end
1238

1239
      # Toggles NOT NULL on a column. When making the column NOT NULL and a
      # non-nil +default+ is supplied, existing NULLs are first backfilled with
      # that default so the constraint can be added.
      def change_column_null(table_name, column_name, null, default = nil)
        clear_cache!
        quoted_table  = quote_table_name(table_name)
        quoted_column = quote_column_name(column_name)

        if !null && !default.nil?
          execute("UPDATE #{quoted_table} SET #{quoted_column}=#{quote(default)} WHERE #{quoted_column} IS NULL")
        end

        action = null ? 'DROP' : 'SET'
        execute("ALTER TABLE #{quoted_table} ALTER #{quoted_column} #{action} NOT NULL")
      end

1247 1248
      # Renames a column in a table.
      def rename_column(table_name, column_name, new_column_name)
        clear_cache!
        quoted_table = quote_table_name(table_name)
        old_column   = quote_column_name(column_name)
        new_column   = quote_column_name(new_column_name)
        execute "ALTER TABLE #{quoted_table} RENAME COLUMN #{old_column} TO #{new_column}"
      end
1252

1253 1254 1255 1256
      # +table_name+ is accepted for interface compatibility but not used;
      # DROP INDEX needs only the (possibly schema-qualified) index name.
      def remove_index!(table_name, index_name) #:nodoc:
        quoted_index = quote_table_name(index_name)
        execute "DROP INDEX #{quoted_index}"
      end

1257 1258 1259 1260
      # Renames an index; +table_name+ is not used.
      # NOTE(review): old_name is quoted as a single identifier while new_name
      # goes through quote_table_name. remove_index! quotes the index with
      # quote_table_name, so the two sides look swapped for schema-qualified
      # names. Confirm the intent before relying on dotted index names.
      def rename_index(table_name, old_name, new_name)
        from = quote_column_name(old_name)
        to = quote_table_name(new_name)
        execute "ALTER INDEX #{from} RENAME TO #{to}"
      end

1261 1262
      # Maximum length allowed for generated index names.
      # NOTE(review): 63 appears to mirror PostgreSQL's default identifier
      # length limit — confirm if the server is built with a different limit.
      def index_name_length
        63
      end
1264

1265 1266
      # Maps logical Rails types to PostgreSQL-specific data types.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        case type.to_s
        when 'binary'
          # bytea takes no length modifier. Any limit up to the 1Gb hard cap
          # (32-bit size field plus TOAST) maps to the generic binary type; a
          # larger request is rejected.
          unless limit.nil? || (0..0x3fffffff) === limit
            raise(ActiveRecordError, "No binary type has byte size #{limit}.")
          end
          super(type)
        when 'integer'
          # Choose the narrowest integer type that fits +limit+ bytes.
          return 'integer' unless limit

          case limit
          when 1, 2 then 'smallint'
          when 3, 4 then 'integer'
          when 5..8 then 'bigint'
          else raise(ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead.")
          end
        else
          super
        end
      end
1288

1289
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
1290 1291 1292
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column.
1293
      #
1294
      #   distinct("posts.id", "posts.created_at desc")
1295 1296
      def distinct(columns, orders) #:nodoc:
        return "DISTINCT #{columns}" if orders.empty?

        # Reduce each ORDER BY entry to its bare expression by stripping
        # ASC/DESC and any NULLS FIRST/LAST. Non-string entries are rendered
        # with to_sql first.
        order_columns = orders.map { |order|
          sql = order.is_a?(String) ? order : order.to_sql
          sql.gsub(/\s+(ASC|DESC)\s*(NULLS\s+(FIRST|LAST)\s*)?/i, '')
        }.reject(&:blank?)

        aliased = order_columns.each_with_index.map { |expr, i| "#{expr} AS alias_#{i}" }
        "DISTINCT #{columns}, #{aliased.join(', ')}"
      end
1309

1310
      module Utils
        extend self

        # Splits a relation reference into <tt>[schema_name, table_name]</tt>.
        #
        # Only the first two identifiers in +name+ are used, and surrounding
        # double quotes are removed from each. +schema_name+ is nil when +name+
        # holds a single identifier; both are nil for an empty +name+.
        # Accepted forms include:
        #
        # * <tt>table_name</tt>
        # * <tt>"table.name"</tt>
        # * <tt>schema_name.table_name</tt>
        # * <tt>schema_name."table.name"</tt>
        # * <tt>"schema.name"."table name"</tt>
        def extract_schema_and_table(name)
          identifiers = name.scan(/[^".\s]+|"[^"]*"/).first(2)
          identifiers = identifiers.map { |identifier| identifier.gsub(/(^"|"$)/, '') }
          table, schema = identifiers.reverse
          [schema, table]
        end
      end

1329 1330
      # Exposes the @use_insert_returning flag. sql_for_insert and exec_insert
      # consult it to decide between an INSERT ... RETURNING clause and a
      # follow-up currval() lookup.
      def use_insert_returning?
        @use_insert_returning
      end

1333
      protected
1334
        # Returns the version of the connected PostgreSQL server.
1335
        def postgresql_version
          # Integer form reported by the driver (connect compares it to 80300).
          @connection.server_version
        end

1339 1340 1341 1342
        # See http://www.postgresql.org/docs/9.1/static/errcodes-appendix.html
        FOREIGN_KEY_VIOLATION = "23503"
        UNIQUE_VIOLATION      = "23505"

        # Maps a driver exception to a more specific Active Record error, based
        # on its SQLSTATE code. An exception with no PG result (a non-driver
        # error, or a driver error whose +result+ is nil) falls through to the
        # superclass. Previously such an exception raised NoMethodError from
        # inside the error-translation path.
        def translate_exception(exception, message)
          result = exception.result if exception.respond_to?(:result)
          sqlstate = result && result.error_field(PGresult::PG_DIAG_SQLSTATE)

          case sqlstate
          when UNIQUE_VIOLATION
            RecordNotUnique.new(message, exception)
          when FOREIGN_KEY_VIOLATION
            InvalidForeignKey.new(message, exception)
          else
            super
          end
        end

D
Initial  
David Heinemeier Hansson 已提交
1354
      private
1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370
      # Fills OID::TYPE_MAP from the server's pg_type catalog.
      def initialize_type_map
        type_rows = execute('SELECT oid, typname, typelem, typdelim FROM pg_type', 'SCHEMA')
        # Rows with typelem '0' are leaves; the rest are built from another type.
        leaf_rows, element_rows = type_rows.partition { |row| row['typelem'] == '0' }

        # Leaves: only type names that OID knows about get a handler.
        leaf_rows.select { |row| OID.registered_type? row['typname'] }.each do |row|
          OID::TYPE_MAP[row['oid'].to_i] = OID::NAMES[row['typname']]
        end

        # Composites: wrap the already-mapped element type in a Vector. The
        # filter is evaluated before any insertion, so a composite whose element
        # is itself a composite is not picked up in this pass.
        element_rows.select { |row| OID::TYPE_MAP.key? row['typelem'].to_i }.each do |row|
          element_type = OID::TYPE_MAP[row['typelem'].to_i]
          OID::TYPE_MAP[row['oid'].to_i] = OID::Vector.new(row['typdelim'], element_type)
        end
      end

1371 1372
        # SQLSTATE 0A000. exec_cache treats it as the server's signal that a
        # cached prepared statement is stale, and re-prepares the statement.
        FEATURE_NOT_SUPPORTED = "0A000" # :nodoc:

1373 1374
        # Sends +sql+ to the server without preparing it. +binds+ is accepted
        # for signature parity with exec_cache but is not used.
        def exec_no_cache(sql, binds)
          connection = @connection
          connection.async_exec(sql)
        end
1376

1377
        # Runs +sql+ as a cached prepared statement with +binds+ type-cast into
        # parameters, and returns the last result. If the server rejects a stale
        # cached plan, the cache entry is dropped and the statement is prepared
        # and run once more.
        def exec_cache(sql, binds)
          # Allow exactly one re-prepare. A stale plan is cured by preparing
          # again. A FEATURE_NOT_SUPPORTED error with any other cause is not, and
          # an unbounded retry would loop forever on it.
          retried = false

          begin
            stmt_key = prepare_statement sql

            # Clear the queue
            @connection.get_last_result
            @connection.send_query_prepared(stmt_key, binds.map { |col, val|
              type_cast(val, col)
            })
            @connection.block
            @connection.get_last_result
          rescue PGError => e
            # Get the PG code for the failure.  Annoyingly, the code for
            # prepared statements whose return value may have changed is
            # FEATURE_NOT_SUPPORTED.  Check here for more details:
            # http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/backend/utils/cache/plancache.c#l573
            begin
              code = e.result.result_error_field(PGresult::PG_DIAG_SQLSTATE)
            rescue
              raise e
            end

            if FEATURE_NOT_SUPPORTED == code && !retried
              retried = true
              @statements.delete sql_key(sql)
              retry
            else
              raise e
            end
          end
        end

        # Returns the statement identifier for the client side cache
        # of statements
        def sql_key(sql)
          # The search path is part of the key, so the same SQL text prepared
          # under a different search_path gets its own cache entry.
          [schema_search_path, sql].join('-')
        end

        # Prepare the statement if it hasn't been prepared, return
        # the statement key.
        def prepare_statement(sql)
          cache_key = sql_key(sql)

          unless @statements.key?(cache_key)
            # First time this SQL has been seen under this key: prepare it on
            # the server under a freshly allocated statement name.
            statement_name = @statements.next_key
            @connection.prepare(statement_name, sql)
            @statements[cache_key] = statement_name
          end

          @statements[cache_key]
        end
1424

P
Pratik Naik 已提交
1425
        # The internal PostgreSQL identifier of the money data type.
1426
        # NOTE(review): not referenced anywhere in this part of the file;
        # confirm it is still used elsewhere before changing or removing it.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:
1427 1428
        # The internal PostgreSQL identifier of the BYTEA data type.
        # NOTE(review): not referenced anywhere in this part of the file;
        # confirm it is still used elsewhere before changing or removing it.
        BYTEA_COLUMN_TYPE_OID = 17 #:nodoc:
1429 1430 1431 1432

        # Connects to a PostgreSQL server and sets up the adapter depending on the
        # connected server's characteristics.
        def connect
          @connection = PGconn.connect(@connection_parameters)

          # Money has a fixed precision of 10 up to PostgreSQL 8.2 and 19 from
          # 8.3 on. PostgreSQLColumn cannot see the server version, so the
          # precision is set here.
          money_precision = postgresql_version >= 80300 ? 19 : 10
          PostgreSQLColumn.money_precision = money_precision

          configure_connection
        end

1443
        # Configures the encoding, verbosity, schema search path, and time zone of the connection.
1444
        # This is called by #connect and should not be called manually.
1445 1446
        def configure_connection
          client_encoding = @config[:encoding]
          @connection.set_client_encoding(client_encoding) if client_encoding

          min_messages = @config[:min_messages]
          self.client_min_messages = min_messages if min_messages
          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]

          # Use standard-conforming strings if available so we don't have to do the E'...' dance.
          set_standard_conforming_strings

          # With Active Record's UTC time zone support the session time zone is
          # set to UTC; otherwise @local_tz is applied when one is set.
          case
          when ActiveRecord::Base.default_timezone == :utc
            execute("SET time zone 'UTC'", 'SCHEMA')
          when @local_tz
            execute("SET time zone '#{@local_tz}'", 'SCHEMA')
          end
        end

1464
        # Returns the current ID of a table's sequence.
1465
        def last_insert_id(sequence_name) #:nodoc:
          # Integer() (not to_i) so a non-numeric value raises instead of becoming 0.
          raw_value = last_insert_id_value(sequence_name)
          Integer(raw_value)
        end

D
Doug Cole 已提交
1469 1470 1471 1472 1473
        # Returns the raw (untyped) currval() value for +sequence_name+.
        def last_insert_id_value(sequence_name)
          first_row = last_insert_id_result(sequence_name).rows.first
          first_row.first
        end

        # Queries currval() for +sequence_name+. The name is wrapped in single
        # quotes but not escaped.
        def last_insert_id_result(sequence_name) #:nodoc:
          sql = "SELECT currval('#{sequence_name}')"
          exec_query(sql, 'SQL')
        end

1477
        # Executes a SELECT query and returns the results, performing any data type
1478
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
1479
        def select(sql, name = nil, binds = [])
          # Pure delegation; no conversion happens at this level.
          exec_query(sql, name, binds)
        end

        # Executes +sql+ and returns <tt>[field_names, rows]</tt>, clearing the
        # driver result once both have been read.
        def select_raw(sql, name = nil)
          pg_result = execute(sql, name)
          rows = result_as_array(pg_result)
          field_names = pg_result.fields
          pg_result.clear
          [field_names, rows]
        end

1491
        # Returns the list of a table's column names, data types, and default values.
1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - ::regclass is a function that gives the id for a table name
1509
        def column_definitions(table_name) #:nodoc:
          sql = <<-end_sql
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), d.adsrc, a.attnotnull, a.atttypid, a.atttypmod
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{quote_table_name(table_name)}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
          # One row per live column: name, formatted type, default source,
          # not-null flag, type oid, type modifier.
          exec_query(sql, 'SCHEMA').rows
        end
1519 1520

        # Splits the leading identifier off +name+. Returns
        # <tt>[identifier, rest]</tt>, where +rest+ has its separating dot
        # removed and is nil when nothing remains. Returns nil if no identifier
        # can be matched.
        def extract_pg_identifier_from_name(name)
          pattern = name.start_with?('"') ? /\"([^\"]+)\"/ : /([^\.]+)/
          match_data = name.match(pattern)
          return unless match_data

          rest = name[match_data[0].length, name.length]
          rest = rest[1..-1] if rest.start_with? "."
          [match_data[1], (rest.empty? ? nil : rest)]
        end
1529

1530 1531 1532 1533 1534
        # Returns the table reference (as written, possibly quoted or
        # schema-qualified) from an INSERT ... VALUES (...) statement, or nil
        # when the statement does not have that shape. The /m flag lets the
        # match cross line breaks, so multi-line INSERTs (e.g. a newline before
        # VALUES) are recognised too.
        def extract_table_ref_from_insert_sql(sql)
          match = sql.match(/into\s+([^\(]*).*values\s*\(/im)
          match[1].strip if match
        end

1535 1536 1537
        # Builds a new TableDefinition bound to this adapter.
        def table_definition
          adapter = self
          TableDefinition.new(adapter)
        end
D
Initial  
David Heinemeier Hansson 已提交
1538 1539 1540
    end
  end
end