require 'active_record/connection_adapters/abstract_adapter'
require 'active_support/core_ext/object/blank'
require 'active_record/connection_adapters/statement_pool'
require 'active_record/connection_adapters/postgresql/oid'
require 'arel/visitors/bind_visitor'

# Make sure we're using pg high enough for PGResult#values
gem 'pg', '~> 0.11'
require 'pg'

# IPAddr is used for the inet/cidr column support below.
require 'ipaddr'

module ActiveRecord
14
  module ConnectionHandling
    # Establishes a connection to the database that's used by all Active Record objects.
    #
    # +config+ is the database configuration hash. Keys that only matter to
    # Active Record are removed, +nil+ values are dropped, and +:username+ /
    # +:database+ are renamed to libpq's +:user+ / +:dbname+. The remaining
    # entries are handed to the adapter as libpq connection parameters.
    def postgresql_connection(config) # :nodoc:
      conn_params = config.symbolize_keys

      # Forward any unused config params to PGconn.connect.
      [:statement_limit, :encoding, :min_messages, :schema_search_path,
       :schema_order, :adapter, :pool, :checkout_timeout, :template,
       :reaping_frequency, :insert_returning].each do |key|
        conn_params.delete key
      end
      conn_params.delete_if { |k,v| v.nil? }

      # Map Active Record's param names to libpq's.
      conn_params[:user] = conn_params.delete(:username) if conn_params[:username]
      conn_params[:dbname] = conn_params.delete(:database) if conn_params[:database]

      # The postgres drivers don't allow the creation of an unconnected PGconn object,
      # so just pass a nil connection object for the time being.
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, conn_params, config)
    end
  end

  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      #
      # +oid_type+ is the OID type object used by #type_cast. +default+ is the
      # raw default expression reported by PostgreSQL; it is reduced to a
      # plain value with ::extract_value_from_default before reaching Column.
      def initialize(name, default, oid_type, sql_type = nil, null = true)
        @oid_type = oid_type
        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end

      # :stopdoc:
      class << self
        attr_accessor :money_precision

        # Maps 'infinity' / '-infinity' to Float infinities; other strings go
        # to Column.string_to_time and non-strings are returned unchanged.
        def string_to_time(string)
          return string unless String === string

          case string
          when 'infinity'  then 1.0 / 0.0
          when '-infinity' then -1.0 / 0.0
          else
            super
          end
        end

        # Serializes a Hash into hstore text ("k"=>"v",...); any other
        # object is returned unchanged.
        def hstore_to_string(object)
          if Hash === object
            object.map { |k,v|
              "#{escape_hstore(k)}=>#{escape_hstore(v)}"
            }.join ','
          else
            object
          end
        end

        # Parses hstore text into a Hash. Surrounding double quotes and
        # backslash escapes are removed; an unquoted NULL value (any case)
        # becomes nil. nil and non-strings are returned unchanged.
        def string_to_hstore(string)
          if string.nil?
            nil
          elsif String === string
            Hash[string.scan(HstorePair).map { |k,v|
              v = v.upcase == 'NULL' ? nil : v.gsub(/^"(.*)"$/,'\1').gsub(/\\(.)/, '\1')
              k = k.gsub(/^"(.*)"$/,'\1').gsub(/\\(.)/, '\1')
              [k,v]
            }]
          else
            string
          end
        end

        # Converts a String into an IPAddr; nil and non-strings pass through.
        def string_to_cidr(string)
          if string.nil?
            nil
          elsif String === string
            IPAddr.new(string)
          else
            string
          end
        end

        # Formats an IPAddr as "address/prefix-length"; other objects are
        # returned unchanged.
        def cidr_to_string(object)
          if IPAddr === object
            # The prefix length is the number of set bits in IPAddr's
            # internal netmask integer (@mask_addr).
            "#{object.to_s}/#{object.instance_variable_get(:@mask_addr).to_s(2).count('1')}"
          else
            object
          end
        end

        private

        # Matches one key=>value pair of hstore text; each side is either a
        # double-quoted string (with backslash escapes) or an unquoted token.
        HstorePair = begin
          quoted_string = /"[^"\\]*(?:\\.[^"\\]*)*"/
          unquoted_string = /(?:\\.|[^\s,])[^\s=,\\]*(?:\\.[^\s=,\\]*|=[^,>])*/
          /(#{quoted_string}|#{unquoted_string})\s*=>\s*(#{quoted_string}|#{unquoted_string})/
        end

        # Quotes one hstore key or value: nil becomes NULL, "" stays "", and
        # anything else is double-quoted with '"' and '\' backslash-escaped.
        def escape_hstore(value)
          if value.nil?
            'NULL'
          elsif value == ""
            '""'
          else
            '"%s"' % value.to_s.gsub(/(["\\])/, '\\\\\1')
          end
        end
      end
      # :startdoc:

      # Extracts the value from a PostgreSQL column default definition.
      #
      # Returns the literal default as a String, true/false for boolean
      # defaults, or nil when there is no default or it is an expression
      # (function call, user type, ...) whose value cannot be known here.
      def self.extract_value_from_default(default)
        # This is a performance optimization for Ruby 1.9.2 in development.
        # If the value is nil, we return nil straight away without checking
        # the regular expressions. If we check each regular expression,
        # Regexp#=== will call NilClass#to_str, which will trigger
        # method_missing (defined by whiny nil in ActiveSupport) which
        # makes this method very very slow.
        return default unless default

        case default
          # Numeric types, including plain integer defaults such as object
          # identifiers. The value may be wrapped in parentheses, e.g. "(-1)";
          # the closing paren is matched outside the capture group so it is
          # not returned as part of the value.
          when /\A\(?(-?\d+(\.\d*)?)\)?\z/
            $1
          # Character types
          when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
            $1
          # Character types (8.1 formatting)
          when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
            $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
          # Binary data types
          when /\A'(.*)'::bytea\z/m
            $1
          # Date/time types
          when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
            $1
          when /\A'(.*)'::interval\z/
            $1
          # Boolean type
          when 'true'
            true
          when 'false'
            false
          # Geometric types
          when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
            $1
          # Network address types
          when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
            $1
          # Bit string types
          when /\AB'(.*)'::"?bit(?: varying)?"?\z/
            $1
          # XML type
          when /\A'(.*)'::xml\z/m
            $1
          # Arrays
          when /\A'(.*)'::"?\D+"?\[\]\z/
            $1
          # Hstore
          when /\A'(.*)'::hstore\z/
            $1
          else
            # Anything else is blank, some user type, or some function
            # and we can't know the value of that, so return nil.
            nil
        end
      end

      # Casts a database value through this column's OID type. Columns for
      # which #encoded? is true use Column#type_cast instead.
      def type_cast(value)
        return if value.nil?
        return super if encoded?

        @oid_type.type_cast value
      end

      private
      def extract_limit(sql_type)
        case sql_type
        when /^bigint/i;    8
        when /^smallint/i;  2
        else super
        end
      end

      # Extracts the scale from PostgreSQL-specific data types.
      def extract_scale(sql_type)
        # Money type has a fixed scale of 2.
        sql_type =~ /^money/ ? 2 : super
      end

      # Extracts the precision from PostgreSQL-specific data types.
      def extract_precision(sql_type)
        if sql_type == 'money'
          self.class.money_precision
        else
          super
        end
      end

      # Maps PostgreSQL-specific data types to logical Rails types.
      def simplified_type(field_type)
        case field_type
        # Numeric and monetary types
        when /^(?:real|double precision)$/
          :float
        # Monetary types
        when 'money'
          :decimal
        when 'hstore'
          :hstore
        # Network address types
        when 'inet'
          :inet
        when 'cidr'
          :cidr
        when 'macaddr'
          :macaddr
        # Character types
        when /^(?:character varying|bpchar)(?:\(\d+\))?$/
          :string
        # Binary data types
        when 'bytea'
          :binary
        # Date/time types
        when /^timestamp with(?:out)? time zone$/
          :datetime
        when 'interval'
          :string
        # Geometric types
        when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
          :string
        # Bit strings
        when /^bit(?: varying)?(?:\(\d+\))?$/
          :string
        # XML type
        when 'xml'
          :xml
        # tsvector type
        when 'tsvector'
          :tsvector
        # Arrays
        when /^\D+\[\]$/
          :string
        # Object identifier types
        when 'oid'
          :integer
        # UUID type
        when 'uuid'
          :string
        # Small and big integer types
        when /^(?:small|big)int$/
          :integer
        # Pass through all types that are not specific to PostgreSQL.
        else
          super
        end
      end
    end

    # The PostgreSQL adapter works with the native C (https://bitbucket.org/ged/ruby-pg) driver.
    #
    # Options:
    #
    # * <tt>:host</tt> - Defaults to a Unix-domain socket in /tmp. On machines without Unix-domain sockets,
    #   the default is to connect to localhost.
    # * <tt>:port</tt> - Defaults to 5432.
    # * <tt>:username</tt> - Defaults to the operating system name of the user running the application.
    # * <tt>:password</tt> - Password to be used if the server demands password authentication.
    # * <tt>:database</tt> - Defaults to the user name.
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection, given
    #   as a string of comma-separated schema names. This is backward-compatible with the <tt>:schema_order</tt> option.
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO
    #   <encoding></tt> call on the connection.
    # * <tt>:min_messages</tt> - An optional client min messages that is used in a
    #   <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
    # * <tt>:insert_returning</tt> - An optional boolean to control the use of <tt>RETURNING</tt> for <tt>INSERT</tt> statements;
    #   defaults to true.
    #
    # Any further options are used as connection parameters to libpq. See
    # http://www.postgresql.org/docs/9.1/static/libpq-connect.html for the
    # list of parameters.
    #
    # In addition, default connection parameters of libpq can be set via environment variables.
    # See http://www.postgresql.org/docs/9.1/static/libpq-envars.html .
    class PostgreSQLAdapter < AbstractAdapter
296 297 298 299 300
      class TableDefinition < ActiveRecord::ConnectionAdapters::TableDefinition
        def xml(*args)
          options = args.extract_options!
          column(args[0], 'xml', options)
        end
301 302 303 304 305

        def tsvector(*args)
          options = args.extract_options!
          column(args[0], 'tsvector', options)
        end
306 307 308 309

        def hstore(name, options = {})
          column(name, 'hstore', options)
        end
310 311 312 313 314 315 316 317 318 319 320 321

        def inet(name, options = {})
          column(name, 'inet', options)
        end

        def cidr(name, options = {})
          column(name, 'cidr', options)
        end

        def macaddr(name, options = {})
          column(name, 'macaddr', options)
        end
322 323
      end

324
      # Name returned by #adapter_name.
      ADAPTER_NAME = 'PostgreSQL'

      # Maps Rails abstract column types to PostgreSQL type names (and a
      # default limit where one applies).
      NATIVE_DATABASE_TYPES = {
        :primary_key => "serial primary key",
        :string      => { :name => "character varying", :limit => 255 },
        :text        => { :name => "text" },
        :integer     => { :name => "integer" },
        :float       => { :name => "float" },
        :decimal     => { :name => "decimal" },
        :datetime    => { :name => "timestamp" },
        :timestamp   => { :name => "timestamp" },
        :time        => { :name => "time" },
        :date        => { :name => "date" },
        :binary      => { :name => "bytea" },
        :boolean     => { :name => "boolean" },
        :xml         => { :name => "xml" },
        :tsvector    => { :name => "tsvector" },
        :hstore      => { :name => "hstore" },
        :inet        => { :name => "inet" },
        :cidr        => { :name => "cidr" },
        :macaddr     => { :name => "macaddr" }
      }

      # Returns 'PostgreSQL' as adapter name for identification purposes.
      def adapter_name
        ADAPTER_NAME
      end

      # Returns +true+, since this connection adapter supports prepared statement
      # caching.
      def supports_statement_cache?
        true
      end

      # Returns +true+: index columns may be given a sort order.
      def supports_index_sort_order?
        true
      end

      # Returns +true+: partial indexes are supported.
      def supports_partial_index?
        true
      end

      # Prepared-statement cache mapping SQL text to statement names.
      #
      # Entries are stored per Process.pid, so a forked child process starts
      # with an empty cache instead of seeing entries recorded by its parent.
      # When the pool holds @max entries (presumably set by the superclass
      # from +max+), the oldest statement is DEALLOCATEd before a new one is
      # stored.
      class StatementPool < ConnectionAdapters::StatementPool
        def initialize(connection, max)
          super
          @counter = 0
          @cache   = Hash.new { |h,pid| h[pid] = {} }
        end

        def each(&block); cache.each(&block); end
        def key?(key);    cache.key?(key); end
        def [](key);      cache[key]; end
        def length;       cache.length; end

        # Name for the next prepared statement ("a1", "a2", ...). Does not
        # advance the counter; #[]= does.
        def next_key
          "a#{@counter + 1}"
        end

        # Stores statement name +key+ for +sql+, first evicting (and
        # deallocating) the oldest entries while the pool is full.
        def []=(sql, key)
          while @max <= cache.size
            dealloc(cache.shift.last)
          end
          @counter += 1
          cache[sql] = key
        end

        # Deallocates every cached statement and empties this process's cache.
        def clear
          cache.each_value do |stmt_key|
            dealloc stmt_key
          end
          cache.clear
        end

        # Deallocates and forgets the statement cached for +sql_key+. A key
        # that is not cached is ignored rather than sending a nameless
        # "DEALLOCATE " to the server.
        def delete(sql_key)
          dealloc cache[sql_key] if cache.key?(sql_key)
          cache.delete sql_key
        end

        private
        # The statement cache belonging to the current process.
        def cache
          @cache[Process.pid]
        end

        # Drops the prepared statement +key+ on the server, skipped when the
        # connection is not usable.
        def dealloc(key)
          @connection.query "DEALLOCATE #{key}" if connection_active?
        end

        def connection_active?
          @connection.status == PGconn::CONNECTION_OK
        rescue PGError
          false
        end
      end

      # PostgreSQL Arel visitor with Arel::Visitors::BindVisitor mixed in;
      # #initialize uses it instead of Arel::Visitors::PostgreSQL when the
      # :prepared_statements config option is false.
      class BindSubstitution < Arel::Visitors::PostgreSQL # :nodoc:
        include Arel::Visitors::BindVisitor
      end

      # Initializes and connects a PostgreSQL adapter.
      #
      # ConnectionHandling#postgresql_connection passes a nil +connection+;
      # #connect opens the real one from +connection_parameters+. +config+ is
      # consulted for :prepared_statements, :statement_limit and
      # :insert_returning.
      #
      # Raises a RuntimeError when the server version is below 8.2.
      def initialize(connection, logger, connection_parameters, config)
        super(connection, logger)

        if config.fetch(:prepared_statements) { true }
          @visitor = Arel::Visitors::PostgreSQL.new self
        else
          @visitor = BindSubstitution.new self
        end

        connection_parameters.delete :prepared_statements

        @connection_parameters, @config = connection_parameters, config

        # @local_tz is initialized as nil to avoid warnings when connect tries to use it
        @local_tz = nil
        @table_alias_length = nil

        connect
        @statements = StatementPool.new @connection,
                                        config.fetch(:statement_limit) { 1000 }

        if postgresql_version < 80200
          raise "Your version of PostgreSQL (#{postgresql_version}) is too old, please upgrade!"
        end

        initialize_type_map
        @local_tz = execute('SHOW TIME ZONE', 'SCHEMA').first["TimeZone"]
        @use_insert_returning = @config.key?(:insert_returning) ? @config[:insert_returning] : true
      end

      # Clears the prepared statements cache.
      def clear_cache!
        @statements.clear
      end

      # Is this connection alive and ready for queries?
      #
      # Returns false instead of raising when the status check itself fails
      # with a PGError.
      def active?
        @connection.status == PGconn::CONNECTION_OK
      rescue PGError
        false
      end

      # Close then reopen the connection.
      #
      # Clears the prepared-statement cache, resets the underlying
      # connection, then re-applies the connection configuration.
      def reconnect!
        clear_cache!
        @connection.reset
        configure_connection
      end

      # Clears the prepared statements cache, then performs the superclass reset.
      def reset!
        clear_cache!
        super
      end

      # Disconnects from the database if already connected. Otherwise, this
      # method does nothing.
      #
      # Closing is best-effort: any StandardError raised by close is ignored
      # and nil is returned.
      def disconnect!
        clear_cache!
        begin
          @connection.close
        rescue StandardError
          nil
        end
      end

      # Returns the NATIVE_DATABASE_TYPES mapping.
      def native_database_types #:nodoc:
        NATIVE_DATABASE_TYPES
      end

      # Returns true, since this connection adapter supports migrations.
      def supports_migrations?
        true
      end

      # Does PostgreSQL support finding primary key on non-Active Record tables?
      # Always true for this adapter.
      def supports_primary_key? #:nodoc:
        true
      end

      # Enable standard-conforming strings if available.
      #
      # client_min_messages is set to 'panic' for the duration (presumably to
      # suppress server notices) and any StandardError raised by the SET is
      # ignored. The previous client_min_messages value is always restored.
      def set_standard_conforming_strings
        old, self.client_min_messages = client_min_messages, 'panic'
        begin
          execute('SET standard_conforming_strings = on', 'SCHEMA')
        rescue StandardError
          nil
        end
      ensure
        self.client_min_messages = old
      end

      # Returns true: INSERT ... RETURNING is supported.
      def supports_insert_with_returning?
        true
      end

      # Returns true: DDL statements can run inside transactions.
      def supports_ddl_transactions?
        true
      end

      # Returns true, since this connection adapter supports savepoints.
      def supports_savepoints?
        true
      end

      # Returns true: EXPLAIN is supported (see #explain).
      def supports_explain?
        true
      end

      # Returns the maximum identifier length configured on the server
      # (SHOW max_identifier_length), memoized after the first call.
      def table_alias_length
        @table_alias_length ||= query('SHOW max_identifier_length')[0][0].to_i
      end

      # QUOTING ==================================================

      # Escapes binary strings for bytea input to the database.
      # Returns nil when +value+ is nil or false.
      def escape_bytea(value)
        PGconn.escape_bytea(value) if value
      end

      # Unescapes bytea output from a database to the binary string it represents.
      # NOTE: This is NOT an inverse of escape_bytea! This is only to be used
      #       on escaped binary output from the database driver.
      # Returns nil when +value+ is nil or false.
      def unescape_bytea(value)
        PGconn.unescape_bytea(value) if value
      end

      # Quotes PostgreSQL-specific data types for SQL input.
      #
      # Without a +column+, or for values/column types not handled here, this
      # defers to the superclass's quoting.
      def quote(value, column = nil) #:nodoc:
        return super unless column

        case value
        when Hash
          case column.sql_type
          when 'hstore' then super(PostgreSQLColumn.hstore_to_string(value), column)
          else super
          end
        when IPAddr
          case column.sql_type
          when 'inet', 'cidr' then super(PostgreSQLColumn.cidr_to_string(value), column)
          else super
          end
        when Float
          if value.infinite? && column.type == :datetime
            "'#{value.to_s.downcase}'"
          elsif value.infinite? || value.nan?
            "'#{value.to_s}'"
          else
            super
          end
        when Numeric
          return super unless column.sql_type == 'money'
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value}'"
        when String
          case column.sql_type
          when 'bytea' then "'#{escape_bytea(value)}'"
          when 'xml'   then "xml '#{quote_string(value)}'"
          when /^bit/
            # These literals are emitted unescaped, so the whole value must
            # match (\A...\z). With ^/$ a value like "1\n'; DROP ..." matched
            # on its first line and was injected verbatim. Anything else is
            # quoted as an ordinary string rather than returning nil.
            case value
            when /\A[01]*\z/      then "B'#{value}'" # Bit-string notation
            when /\A[0-9A-F]*\z/i then "X'#{value}'" # Hexadecimal notation
            else super
            end
          else
            super
          end
        else
          super
        end
      end

      # Casts +value+ for binding against +column+:
      # * bytea Strings become { :value => value, :format => 1 } (presumably
      #   libpq's binary parameter format),
      # * hstore Hashes are serialized with PostgreSQLColumn.hstore_to_string,
      # * IPAddr values for inet/cidr columns are formatted with
      #   PostgreSQLColumn.cidr_to_string.
      # Anything else, or a nil +column+, defers to the superclass.
      def type_cast(value, column)
        return super unless column

        case value
        when String
          return super unless 'bytea' == column.sql_type
          { :value => value, :format => 1 }
        when Hash
          return super unless 'hstore' == column.sql_type
          PostgreSQLColumn.hstore_to_string(value)
        when IPAddr
          return super unless ['inet', 'cidr'].include?(column.sql_type)
          PostgreSQLColumn.cidr_to_string(value)
        else
          super
        end
      end

      # Quotes strings for use in SQL input, using the connection's escape.
      def quote_string(s) #:nodoc:
        @connection.escape(s)
      end

      # Quotes a table name, optionally schema-qualified, for use in SQL.
      #
      # Checks the following cases:
      #
      # - table_name
      # - "table.name"
      # - schema_name.table_name
      # - schema_name."table.name"
      # - "schema.name".table_name
      # - "schema.name"."table.name"
      def quote_table_name(name)
        schema, name_part = extract_pg_identifier_from_name(name.to_s)

        # A single identifier is the table itself.
        return quote_column_name(schema) unless name_part

        table_name = extract_pg_identifier_from_name(name_part).first
        "#{quote_column_name(schema)}.#{quote_column_name(table_name)}"
      end

      # Quotes column names for use in SQL queries (via PGconn.quote_ident).
      def quote_column_name(name) #:nodoc:
        PGconn.quote_ident(name.to_s)
      end

      # Quote date/time values for use in SQL input. Includes microseconds
      # (zero-padded to six digits) if the value is a Time responding to usec.
      def quoted_date(value) #:nodoc:
        if value.acts_like?(:time) && value.respond_to?(:usec)
          "#{super}.#{sprintf("%06d", value.usec)}"
        else
          super
        end
      end

      # Set the authorized user for this session. Clears the prepared
      # statements cache first.
      #
      # NOTE(review): +user+ is interpolated into the SQL verbatim (unquoted);
      # it must not come from untrusted input.
      def session_auth=(user)
        clear_cache!
        exec_query "SET SESSION AUTHORIZATION #{user}"
      end

      # REFERENTIAL INTEGRITY ====================================

      # Returns true: #disable_referential_integrity is supported.
      def supports_disable_referential_integrity? #:nodoc:
        true
      end

      # Runs the block with all triggers disabled on every table returned by
      # #tables, re-enabling them afterwards even if the block raises.
      # Returns the block's value. Nothing is toggled unless
      # supports_disable_referential_integrity? is true.
      def disable_referential_integrity #:nodoc:
        if supports_disable_referential_integrity?
          execute(tables.map { |name| "ALTER TABLE #{quote_table_name(name)} DISABLE TRIGGER ALL" }.join(";"))
        end
        yield
      ensure
        if supports_disable_referential_integrity?
          execute(tables.map { |name| "ALTER TABLE #{quote_table_name(name)} ENABLE TRIGGER ALL" }.join(";"))
        end
      end

      # DATABASE STATEMENTS ======================================

      # Runs EXPLAIN on the SQL for +arel+ with +binds+ and returns the plan
      # formatted by ExplainPrettyPrinter.
      def explain(arel, binds = [])
        sql = "EXPLAIN #{to_sql(arel, binds)}"
        ExplainPrettyPrinter.new.pp(exec_query(sql, 'EXPLAIN', binds))
      end

      class ExplainPrettyPrinter # :nodoc:
        # Pretty prints the result of a EXPLAIN in a way that resembles the output of the
        # PostgreSQL shell:
        #
        #                                     QUERY PLAN
        #   ------------------------------------------------------------------------------
        #    Nested Loop Left Join  (cost=0.00..37.24 rows=8 width=0)
        #      Join Filter: (posts.user_id = users.id)
        #      ->  Index Scan using users_pkey on users  (cost=0.00..8.27 rows=1 width=4)
        #            Index Cond: (id = 1)
        #      ->  Seq Scan on posts  (cost=0.00..28.88 rows=8 width=4)
        #            Filter: (posts.user_id = 1)
        #   (6 rows)
        #
        def pp(result)
          header = result.columns.first
          lines  = result.rows.map(&:first)

          # We add 2 because there's one char of padding at both sides, note
          # the extra hyphens in the example above.
          width = [header, *lines].map(&:length).max + 2

          out = []
          out << header.center(width).rstrip
          out << '-' * width
          out.concat(lines.map { |line| " #{line}" })

          nrows = result.rows.length
          rows_label = nrows == 1 ? 'row' : 'rows'
          out << "(#{nrows} #{rows_label})"

          out.join("\n") + "\n"
        end
      end

      # Executes a SELECT query and returns an array of rows. Each row is an
      # array of field values.
      def select_rows(sql, name = nil)
        select_raw(sql, name).last
      end

      # Executes an INSERT query and returns the new record's ID.
      #
      # When +pk+ is not given it is looked up from the table named in +sql+.
      # With a primary key and insert_returning enabled, the id is fetched
      # with RETURNING. Otherwise the superclass runs the INSERT and the id is
      # read from +sequence_name+, or from the table's default sequence.
      def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        unless pk
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          pk = primary_key(table_ref) if table_ref
        end

        if pk && use_insert_returning?
          select_value("#{sql} RETURNING #{quote_column_name(pk)}")
        elsif pk
          super
          # table_ref is only assigned above when pk had to be looked up. When
          # the caller supplied pk, derive the table from the SQL so the
          # default sequence name is not built from a nil table.
          table_ref ||= extract_table_ref_from_insert_sql(sql)
          last_insert_id_value(sequence_name || default_sequence_name(table_ref, pk))
        else
          super
        end
      end
      # +create+ is an alias of +insert+.
      alias :create :insert

      # Creates a 2D array representing the result set of PGresult +res+.
      #
      # Values are returned as the driver provides them, except that bytea
      # columns are passed through #unescape_bytea and money columns have
      # currency symbols and thousands separators stripped, with "." as the
      # decimal separator. The rows array from res.values is modified in place
      # and returned.
      def result_as_array(res) #:nodoc:
        # check if we have any binary column and if they need escaping
        ftypes = Array.new(res.nfields) do |i|
          [i, res.ftype(i)]
        end

        rows = res.values
        return rows unless ftypes.any? { |_, x|
          x == BYTEA_COLUMN_TYPE_OID || x == MONEY_COLUMN_TYPE_OID
        }

        typehash = ftypes.group_by { |_, type| type }
        binaries = typehash[BYTEA_COLUMN_TYPE_OID] || []
        monies   = typehash[MONEY_COLUMN_TYPE_OID] || []

        rows.each do |row|
          # unescape string passed BYTEA field (OID == 17)
          binaries.each do |index, _|
            row[index] = unescape_bytea(row[index])
          end

          # If this is a money type column and there are any currency symbols,
          # then strip them off. Indeed it would be prettier to do this in
          # PostgreSQLColumn.string_to_decimal but would break form input
          # fields that call value_before_type_cast.
          monies.each do |index, _|
            data = row[index]
            # Because money output is formatted according to the locale, there are two
            # cases to consider (note the decimal separators):
            #  (1) $12,345,678.12
            #  (2) $12.345.678,12
            # Non-bang gsub/sub are used because gsub!/sub! return nil when
            # nothing changes, which would break the chained call.
            case data
            when /^-?\D+[\d,]+\.\d{2}$/  # (1)
              row[index] = data.gsub(/[^-\d.]/, '')
            when /^-?\D+[\d.]+,\d{2}$/  # (2)
              row[index] = data.gsub(/[^-\d,]/, '').sub(/,/, '.')
            end
          end
        end
      end

      # Queries the database and returns the results in an Array-like object
      # (see #result_as_array). The query is wrapped in #log.
      def query(sql, name = nil) #:nodoc:
        log(sql, name) do
          result_as_array @connection.async_exec(sql)
        end
      end

      # Executes an SQL statement, returning a PGresult object on success
      # or raising a PGError exception otherwise. The call is wrapped in #log.
      def execute(sql, name = nil)
        log(sql, name) do
          @connection.async_exec(sql)
        end
      end

      # Returns the bind placeholder for the zero-based +index+ ("$1", "$2", ...).
      def substitute_at(column, index)
        Arel::Nodes::BindParam.new "$#{index + 1}"
      end

      # ActiveRecord::Result that also stores the per-column type objects
      # (column name => OID type) built by #exec_query.
      class Result < ActiveRecord::Result
        def initialize(columns, rows, column_types)
          super(columns, rows)
          @column_types = column_types
        end
      end

      # Executes +sql+ and returns a Result with the rows and, for each
      # column, the type object from OID::TYPE_MAP. With +binds+ the
      # statement goes through the prepared-statement cache. Unknown OIDs
      # produce a warning and fall back to OID::Identity. The PGresult is
      # cleared even if building the Result raises.
      def exec_query(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = binds.empty? ? exec_no_cache(sql, binds) :
                                  exec_cache(sql, binds)

          begin
            types = {}
            result.fields.each_with_index do |fname, i|
              ftype = result.ftype i
              fmod  = result.fmod i
              types[fname] = OID::TYPE_MAP.fetch(ftype, fmod) { |oid, mod|
                warn "unknown OID: #{fname}(#{oid}) (#{sql})"
                OID::Identity.new
              }
            end

            return Result.new(result.fields, result.values, types)
          ensure
            result.clear
          end
        end
      end

      # Executes a DELETE (or, via the exec_update alias, an UPDATE) and
      # returns the number of affected rows. With +binds+ the statement goes
      # through the prepared-statement cache. The PGresult is cleared even
      # if reading the row count raises.
      def exec_delete(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = binds.empty? ? exec_no_cache(sql, binds) :
                                  exec_cache(sql, binds)
          begin
            result.cmd_tuples
          ensure
            result.clear
          end
        end
      end
      alias :exec_update :exec_delete

      # Returns [sql, binds] for an INSERT. When a primary key is known
      # (looked up from the table in +sql+ if +pk+ is nil) and
      # insert_returning is enabled, appends "RETURNING <pk>".
      def sql_for_insert(sql, pk, id_value, sequence_name, binds)
        unless pk
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          pk = primary_key(table_ref) if table_ref
        end

        if pk && use_insert_returning?
          sql = "#{sql} RETURNING #{quote_column_name(pk)}"
        end

        [sql, binds]
      end

      # Executes an INSERT via exec_query.
      #
      # When RETURNING is not in use and +pk+ is given, returns
      # last_insert_id_result for +sequence_name+, which defaults to the
      # table's default sequence derived from +sql+. If no sequence can be
      # determined, or RETURNING is in use, returns the exec_query result.
      def exec_insert(sql, name, binds, pk = nil, sequence_name = nil)
        val = exec_query(sql, name, binds)
        return val if use_insert_returning? || !pk

        unless sequence_name
          table_ref = extract_table_ref_from_insert_sql(sql)
          sequence_name = default_sequence_name(table_ref, pk)
          return val unless sequence_name
        end

        last_insert_id_result(sequence_name)
      end

      # Executes an UPDATE query and returns the number of affected tuples.
      def update_sql(sql, name = nil)
        super.cmd_tuples
      end

      # Begins a transaction.
      def begin_db_transaction
        execute "BEGIN"
      end

      # Commits a transaction.
      def commit_db_transaction
        execute "COMMIT"
      end

      # Aborts a transaction.
      def rollback_db_transaction
        execute "ROLLBACK"
      end

      # True when the connection's transaction status is PGconn::PQTRANS_IDLE.
      def outside_transaction?
        @connection.transaction_status == PGconn::PQTRANS_IDLE
      end

      # Creates a savepoint named by #current_savepoint_name.
      def create_savepoint
        savepoint = current_savepoint_name
        execute "SAVEPOINT #{savepoint}"
      end

      # Rolls back to the savepoint named by #current_savepoint_name.
      def rollback_to_savepoint
        savepoint = current_savepoint_name
        execute "ROLLBACK TO SAVEPOINT #{savepoint}"
      end

897
      # Releases the savepoint named by #current_savepoint_name.
      def release_savepoint
        savepoint = current_savepoint_name
        execute "RELEASE SAVEPOINT #{savepoint}"
      end
900

901 902
      # SCHEMA STATEMENTS ========================================

903 904 905
      # Drops the database called +name+ and then creates it again, passing
      # +options+ through to #create_database.
      def recreate_database(name, options = {}) #:nodoc:
        drop_database name
        create_database name, options
      end

910
      # Create a new PostgreSQL database. Options include <tt>:owner</tt>, <tt>:template</tt>,
      # <tt>:encoding</tt>, <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
      # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
      #
      # Option keys may be strings or symbols; any other keys are ignored. The
      # encoding defaults to "utf8" when none is given.
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        # Normalize keys before applying the default, so a string "encoding"
        # key reliably overrides it. The default is inserted first, so ENCODING
        # stays the first clause.
        settings = { :encoding => "utf8" }
        options.each do |key, value|
          settings[key.respond_to?(:to_sym) ? key.to_sym : key] = value
        end

        # map/join instead of Enumerable#sum: Ruby 2.4+ defines a numeric #sum
        # that cannot concatenate strings without an explicit "" start value.
        option_string = settings.map do |key, value|
          case key
          when :owner            then " OWNER = \"#{value}\""
          when :template         then " TEMPLATE = \"#{value}\""
          when :encoding         then " ENCODING = '#{value}'"
          when :tablespace       then " TABLESPACE = \"#{value}\""
          when :connection_limit then " CONNECTION LIMIT = #{value}"
          else ""
          end
        end.join

        execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
      end

940
      # Drops the PostgreSQL database +name+. The statement uses IF EXISTS, so
      # a database that does not exist is not an error.
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        quoted_name = quote_table_name(name)
        execute "DROP DATABASE IF EXISTS #{quoted_name}"
      end

948 949
      # Returns the names of all tables in schemas on the current search path
      # (current_schemas(false)). +name+ is accepted for interface
      # compatibility and is not used.
      def tables(name = nil)
        rows = query(<<-SQL, 'SCHEMA')
          SELECT tablename
          FROM pg_tables
          WHERE schemaname = ANY (current_schemas(false))
        SQL
        rows.map(&:first)
      end

957
      # Returns true if table exists.
      # If the schema is not specified as part of +name+ then it will only find tables within
      # the current schema search path (regardless of permissions to access tables in other schemas)
      #
      # Both relkind 'r' and 'v' rows in pg_class are counted, so views match too.
      def table_exists?(name)
        schema, table = Utils.extract_schema_and_table(name.to_s)
        return false unless table

        binds = [[nil, table]]
        binds << [nil, schema] if schema
        namespace_filter = schema ? '$2' : 'ANY (current_schemas(false))'

        count = exec_query(<<-SQL, 'SCHEMA', binds).rows.first[0]
            SELECT COUNT(*)
            FROM pg_class c
            LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind in ('v','r')
            AND c.relname = $1
            AND n.nspname = #{namespace_filter}
        SQL
        count.to_i > 0
      end

977 978 979 980 981 982 983 984
      # Returns true if pg_namespace has a schema named exactly +name+.
      def schema_exists?(name)
        binds = [[nil, name]]
        count = exec_query(<<-SQL, 'SCHEMA', binds).rows.first[0]
          SELECT COUNT(*)
          FROM pg_namespace
          WHERE nspname = $1
        SQL
        count.to_i > 0
      end
985

986
      # Returns an array of IndexDefinition objects for the non-primary-key
      # indexes on the table +table_name+. Only tables in schemas on the current
      # search path are considered. +name+ is only the log label for the
      # first query.
      #
      # Indexes whose key columns cannot be resolved to names (no indkey attnum
      # matched a pg_attribute row) are skipped. DESC columns and a partial
      # index's WHERE clause are parsed from pg_get_indexdef's output.
      def indexes(table_name, name = nil)
        # Quote the table name as a SQL string literal; interpolating it raw let
        # a quote character break (or inject into) the query.
        result = query(<<-SQL, name)
           SELECT distinct i.relname, d.indisunique, d.indkey, pg_get_indexdef(d.indexrelid), t.oid
           FROM pg_class t
           INNER JOIN pg_index d ON t.oid = d.indrelid
           INNER JOIN pg_class i ON d.indexrelid = i.oid
           WHERE i.relkind = 'i'
             AND d.indisprimary = 'f'
             AND t.relname = #{quote(table_name.to_s)}
             AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname = ANY (current_schemas(false)) )
          ORDER BY i.relname
        SQL

        result.map do |row|
          index_name = row[0]
          unique     = row[1] == 't'
          indkey     = row[2].split(" ")
          inddef     = row[3]
          oid        = row[4]

          columns = Hash[query(<<-SQL, "Columns for index #{row[0]} on #{table_name}")]
          SELECT a.attnum, a.attname
          FROM pg_attribute a
          WHERE a.attrelid = #{oid}
          AND a.attnum IN (#{indkey.join(",")})
          SQL

          # Keep the index's column order; attnums without a name are dropped.
          column_names = columns.values_at(*indkey).compact

          # add info on sort order for columns (only desc order is explicitly specified, asc is the default)
          desc_order_columns = inddef.scan(/(\w+) DESC/).flatten
          orders = Hash[desc_order_columns.map { |order_column| [order_column, :desc] }]
          where = inddef.scan(/WHERE (.+)$/).flatten[0]

          column_names.empty? ? nil : IndexDefinition.new(table_name, index_name, unique, column_names, [], orders, where)
        end.compact
      end

1026
      # Returns a PostgreSQLColumn for every column of +table_name+, in the
      # order given by #column_definitions. The type object comes from
      # OID::TYPE_MAP keyed by the column's type oid and modifier, with
      # OID::Identity as the fallback.
      def columns(table_name)
        # Limit, precision, and scale are all handled by the superclass.
        column_definitions(table_name).map do |column_name, type, default, notnull, oid, fmod|
          type_object = OID::TYPE_MAP.fetch(oid.to_i, fmod.to_i) { OID::Identity.new }
          nullable = notnull == 'f'
          PostgreSQLColumn.new(column_name, default, type_object, type, nullable)
        end
      end

1037 1038 1039 1040 1041
      # Returns the name of the database this connection is using, as reported
      # by current_database().
      def current_database
        rows = query('select current_database()')
        rows.first.first
      end

1042 1043 1044 1045 1046
      # Returns the name of the current schema, as reported by current_schema.
      def current_schema
        rows = query('SELECT current_schema', 'SCHEMA')
        rows.first.first
      end

1047 1048 1049 1050 1051 1052 1053 1054
      # Returns the current database encoding format (e.g. "UTF8").
      #
      # The database row is matched with "= current_database()" on the server.
      # The old LIKE against an interpolated name treated "_" and "%" in the
      # name as wildcards and did not escape quotes.
      def encoding
        query(<<-end_sql)[0][0]
          SELECT pg_encoding_to_char(pg_database.encoding) FROM pg_database
          WHERE pg_database.datname = current_database()
        end_sql
      end

1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065
      # Returns an array of schema names, sorted by name. Schemas matching
      # '^pg_.*' and information_schema are excluded.
      def schema_names
        rows = query(<<-SQL)
          SELECT nspname
            FROM pg_namespace
           WHERE nspname !~ '^pg_.*'
             AND nspname NOT IN ('information_schema')
           ORDER by nspname;
        SQL
        rows.flatten
      end

1066 1067 1068 1069 1070 1071 1072 1073 1074 1075
      # Creates a schema named +schema_name+. The name is interpolated into
      # the statement as given (no identifier quoting is applied).
      def create_schema(schema_name)
        execute("CREATE SCHEMA #{schema_name}")
      end

      # Drops the schema +schema_name+ together with everything in it
      # (CASCADE). The name is interpolated as given, without quoting.
      def drop_schema(schema_name)
        execute("DROP SCHEMA #{schema_name} CASCADE")
      end

1076 1077 1078 1079 1080 1081
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      #
      # A nil/false value leaves the search path untouched. Otherwise the value
      # is also cached for #schema_search_path.
      #
      # This should not be called manually but set in database.yml.
      def schema_search_path=(schema_csv)
        return unless schema_csv

        execute("SET search_path TO #{schema_csv}", 'SCHEMA')
        @schema_search_path = schema_csv
      end

1088 1089
      # Returns the active schema search path. The first call asks the server
      # with SHOW search_path; the answer is then cached.
      def schema_search_path
        return @schema_search_path if @schema_search_path

        @schema_search_path = query('SHOW search_path', 'SCHEMA').first.first
      end
1092

1093 1094
      # Returns the current client message level (SHOW client_min_messages).
      def client_min_messages
        rows = query('SHOW client_min_messages', 'SCHEMA')
        rows.first.first
      end

      # Sets the client message level for this session. +level+ is placed
      # inside single quotes without escaping.
      def client_min_messages=(level)
        statement = "SET client_min_messages TO '#{level}'"
        execute(statement, 'SCHEMA')
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      #
      # The column is +pk+, defaulting to "id". The server is asked for its
      # serial sequence (#serial_sequence) and any schema prefix is removed;
      # nil means the column has none. If the lookup raises StatementInvalid,
      # the conventional "<table>_<column>_seq" name is returned instead.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        sequence = serial_sequence(table_name, pk || 'id')
        sequence.split('.').last if sequence
      rescue ActiveRecord::StatementInvalid
        "#{table_name}_#{pk || 'id'}_seq"
      end

      # Returns pg_get_serial_sequence(table, column): the name of the sequence
      # backing that column, or nil when the server reports none.
      def serial_sequence(table, column)
        binds = [[nil, table], [nil, column]]
        exec_query(<<-eosql, 'SCHEMA', binds).rows.first.first
          SELECT pg_get_serial_sequence($1, $2)
        eosql
      end

1119 1120
      # Resets the sequence of a table's primary key to the maximum value.
      #
      # Missing +pk+/+sequence+ arguments are filled in from
      # #pk_and_sequence_for. If a primary key is found but no sequence, a
      # warning is logged (when a logger is set) and nothing changes.
      # Otherwise setval(..., false) makes the next value MAX(pk) +
      # increment_by, or the sequence's min_value for an empty table.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        if !pk || !sequence
          default_pk, default_sequence = pk_and_sequence_for(table)
          pk ||= default_pk
          sequence ||= default_sequence
        end

        return unless pk

        unless sequence
          @logger.warn "#{table} has primary key #{pk} with no default sequence" if @logger
          return
        end

        quoted_sequence = quote_table_name(sequence)

        select_value <<-end_sql, 'Reset sequence'
            SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
          end_sql
      end

1141 1142
      # Returns a table's primary key and belonging sequence as
      # <tt>[pk_column, sequence_name]</tt>.
      #
      # Two catalog lookups are tried in turn. The first finds a sequence with
      # a pg_depend dependency on the first primary-key column. The second
      # parses the sequence name out of that column's nextval(...) default.
      # Any error along the way (including neither lookup finding a row) is
      # rescued and nil is returned.
      def pk_and_sequence_for(table) #:nodoc:
        # First try looking for a sequence with a dependency on the
        # given table's primary key.
        row = query(<<-end_sql, 'PK and serial sequence')[0]
          SELECT attr.attname, seq.relname
          FROM pg_class      seq,
               pg_attribute  attr,
               pg_depend     dep,
               pg_namespace  name,
               pg_constraint cons
          WHERE seq.oid           = dep.objid
            AND seq.relkind       = 'S'
            AND attr.attrelid     = dep.refobjid
            AND attr.attnum       = dep.refobjsubid
            AND attr.attrelid     = cons.conrelid
            AND attr.attnum       = cons.conkey[1]
            AND cons.contype      = 'p'
            AND dep.refobjid      = '#{quote_table_name(table)}'::regclass
        end_sql

        if row.nil? || row.empty?
          # If that fails, try parsing the primary key's default value.
          # Support the 7.x and 8.0 nextval('foo'::text) as well as
          # the 8.1+ nextval('foo'::regclass).
          row = query(<<-end_sql, 'PK and custom sequence')[0]
            SELECT attr.attname,
              CASE
                WHEN split_part(def.adsrc, '''', 2) ~ '.' THEN
                  substr(split_part(def.adsrc, '''', 2),
                         strpos(split_part(def.adsrc, '''', 2), '.')+1)
                ELSE split_part(def.adsrc, '''', 2)
              END
            FROM pg_class       t
            JOIN pg_attribute   attr ON (t.oid = attrelid)
            JOIN pg_attrdef     def  ON (adrelid = attrelid AND adnum = attnum)
            JOIN pg_constraint  cons ON (conrelid = adrelid AND adnum = conkey[1])
            WHERE t.oid = '#{quote_table_name(table)}'::regclass
              AND cons.contype = 'p'
              AND def.adsrc ~* 'nextval'
          end_sql
        end

        pk_column, sequence_name = row.first, row.last
        [pk_column, sequence_name]
      rescue
        # Deliberate best-effort: callers treat nil as "unknown".
        nil
      end

1189 1190
      # Returns just a table's primary key column name, or nil if none is found.
      # Only the first column of the primary-key constraint (conkey[1]) is
      # considered, and only when pg_depend references it.
      def primary_key(table)
        binds = [[nil, table]]
        row = exec_query(<<-end_sql, 'SCHEMA', binds).rows.first
          SELECT DISTINCT(attr.attname)
          FROM pg_attribute attr
          INNER JOIN pg_depend dep ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          WHERE cons.contype = 'p'
            AND dep.refobjid = $1::regclass
        end_sql

        row.first if row
      end

1203
      # Renames a table. #clear_cache! is called before the ALTER TABLE is issued.
      #
      # Example:
      #   rename_table('octopuses', 'octopi')
      def rename_table(name, new_name)
        clear_cache!
        from = quote_table_name(name)
        to   = quote_table_name(new_name)
        execute "ALTER TABLE #{from} RENAME TO #{to}"
      end
1211

1212 1213
      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
      # The SQL type comes from #type_to_sql (:limit/:precision/:scale), and
      # add_column_options! appends the remaining option clauses.
      def add_column(table_name, column_name, type, options = {})
        clear_cache!
        sql_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])
        statement = "ALTER TABLE #{quote_table_name(table_name)} ADD COLUMN #{quote_column_name(column_name)} #{sql_type}"
        add_column_options!(statement, options)

        execute statement
      end
D
Initial  
David Heinemeier Hansson 已提交
1221

1222 1223
      # Changes the type of a column. It also updates the default when
      # +options+ includes :default (per options_include_default?), and the
      # nullability when :null is present.
      def change_column(table_name, column_name, type, options = {})
        clear_cache!
        quoted_table_name = quote_table_name(table_name)
        sql_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])

        execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quote_column_name(column_name)} TYPE #{sql_type}"

        change_column_default(table_name, column_name, options[:default]) if options_include_default?(options)
        change_column_null(table_name, column_name, options[:null], options[:default]) if options.key?(:null)
      end
1232

1233 1234
      # Changes the default value of a table column to +default+, quoted via #quote.
      def change_column_default(table_name, column_name, default)
        clear_cache!
        table  = quote_table_name(table_name)
        column = quote_column_name(column_name)
        execute "ALTER TABLE #{table} ALTER COLUMN #{column} SET DEFAULT #{quote(default)}"
      end
1238

1239
      # Sets (null = false) or drops (null = true) the NOT NULL constraint on a
      # column. When making a column NOT NULL with a non-nil +default+, existing
      # NULLs are first updated to that default.
      def change_column_null(table_name, column_name, null, default = nil)
        clear_cache!
        table  = quote_table_name(table_name)
        column = quote_column_name(column_name)

        if !null && !default.nil?
          execute("UPDATE #{table} SET #{column}=#{quote(default)} WHERE #{column} IS NULL")
        end
        execute("ALTER TABLE #{table} ALTER #{column} #{null ? 'DROP' : 'SET'} NOT NULL")
      end

1247 1248
      # Renames column +column_name+ of +table_name+ to +new_column_name+.
      def rename_column(table_name, column_name, new_column_name)
        clear_cache!
        table = quote_table_name(table_name)
        from  = quote_column_name(column_name)
        to    = quote_column_name(new_column_name)
        execute "ALTER TABLE #{table} RENAME COLUMN #{from} TO #{to}"
      end
1252

1253 1254 1255 1256
      # Drops index +index_name+ (quoted like a possibly schema-qualified
      # table name). +table_name+ is not used.
      def remove_index!(table_name, index_name) #:nodoc:
        quoted_index = quote_table_name(index_name)
        execute "DROP INDEX #{quoted_index}"
      end

1257 1258 1259 1260
      # Renames index +old_name+ to +new_name+. +table_name+ is not used.
      #
      # ALTER INDEX accepts a schema-qualified existing name but only a bare
      # new name. So the old name is quoted table-style (may carry a schema)
      # and the new name as a single identifier.
      def rename_index(table_name, old_name, new_name)
        execute "ALTER INDEX #{quote_table_name(old_name)} RENAME TO #{quote_column_name(new_name)}"
      end

1261 1262
      # Maximum index name length used by this adapter. Presumably this mirrors
      # PostgreSQL's default identifier limit (NAMEDATALEN - 1); confirm if the
      # server is built with a different NAMEDATALEN.
      def index_name_length; 63; end
1264

1265 1266
      # Maps logical Rails types to PostgreSQL-specific data types.
      #
      # * integer: the byte +limit+ selects smallint (1-2), integer (3-4) or
      #   bigint (5-8). No limit means plain integer; any other limit raises.
      # * binary: any +limit+ up to 0x3fffffff is dropped before delegating;
      #   larger limits raise.
      # Every other type is delegated to the superclass with all arguments.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        sql_type = type.to_s

        if sql_type == 'integer'
          return 'integer' unless limit

          case limit
          when 1, 2 then 'smallint'
          when 3, 4 then 'integer'
          when 5..8 then 'bigint'
          else
            raise ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead."
          end
        elsif sql_type == 'binary'
          # PostgreSQL doesn't support limits on binary (bytea) columns.
          # The hard limit is 1Gb, because of a 32-bit size field, and TOAST.
          unless limit.nil? || (0..0x3fffffff) === limit
            raise ActiveRecordError, "No binary type has byte size #{limit}."
          end
          super(type)
        else
          super
        end
      end
1288

1289
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column. Each ORDER BY
      # expression, stripped of ASC/DESC and NULLS FIRST/LAST, is therefore
      # appended as "<expr> AS alias_<n>". Non-String entries are converted with #to_sql.
      #
      #   distinct("posts.id", ["posts.created_at desc"])
      #   # => "DISTINCT posts.id, posts.created_at AS alias_0"
      def distinct(columns, orders) #:nodoc:
        return "DISTINCT #{columns}" if orders.empty?

        # Construct a clean list of column names from the ORDER BY clause, removing
        # any ASC/DESC modifiers
        order_columns = orders.map do |order|
          order = order.to_sql unless order.is_a?(String)
          order.gsub(/\s+(ASC|DESC)\s*(NULLS\s+(FIRST|LAST)\s*)?/i, '')
        end.reject { |column| column.blank? }

        # Every ORDER BY entry was blank: appending ", " would produce invalid SQL.
        return "DISTINCT #{columns}" if order_columns.empty?

        aliased = order_columns.each_with_index.map { |column, i| "#{column} AS alias_#{i}" }
        "DISTINCT #{columns}, #{aliased.join(', ')}"
      end
1309

1310
      module Utils
        extend self

        # Returns an array of <tt>[schema_name, table_name]</tt> extracted from +name+.
        # +schema_name+ is nil if not specified in +name+.
        # +schema_name+ and +table_name+ exclude surrounding quotes (regardless of whether provided in +name+)
        # +name+ supports the range of schema/table references understood by PostgreSQL, for example:
        #
        # * <tt>table_name</tt>
        # * <tt>"table.name"</tt>
        # * <tt>schema_name.table_name</tt>
        # * <tt>schema_name."table.name"</tt>
        # * <tt>"schema.name"."table name"</tt>
        #
        # Only the first two identifiers are used; both are nil for an empty +name+.
        def extract_schema_and_table(name)
          first, second = name.scan(/[^".\s]+|"[^"]*"/).first(2).map { |part| part.gsub(/(^"|"$)/, '') }
          second ? [first, second] : [nil, first]
        end
      end

1329 1330
      # Whether INSERTs should use "RETURNING <pk>". Returns the
      # @use_insert_returning flag exactly as stored.
      def use_insert_returning?; @use_insert_returning; end

1333
      protected
        # Returns the server version number reported by the underlying
        # connection (@connection.server_version).
        def postgresql_version
          version = @connection.server_version
          version
        end

1339 1340 1341 1342
        # SQLSTATE codes recognised by #translate_exception.
        # See http://www.postgresql.org/docs/9.1/static/errcodes-appendix.html
        UNIQUE_VIOLATION      = "23505"
        FOREIGN_KEY_VIOLATION = "23503"

1343
        # Maps a driver exception to an Active Record error class by SQLSTATE:
        # unique violations become RecordNotUnique and foreign-key violations
        # become InvalidForeignKey. Anything else is handled by the superclass.
        #
        # Errors with no server result (e.g. a lost connection, or exceptions
        # without #result at all) have no SQLSTATE and go straight to super,
        # rather than failing with NoMethodError on a nil result.
        def translate_exception(exception, message)
          result = exception.respond_to?(:result) ? exception.result : nil
          sqlstate = result && result.error_field(PGresult::PG_DIAG_SQLSTATE)

          case sqlstate
          when UNIQUE_VIOLATION
            RecordNotUnique.new(message, exception)
          when FOREIGN_KEY_VIOLATION
            InvalidForeignKey.new(message, exception)
          else
            super
          end
        end

D
Initial  
David Heinemeier Hansson 已提交
1354
      private
      # Populates OID::TYPE_MAP from pg_type.
      #
      # Rows with typelem '0' are leaves; those whose typname is a registered
      # OID type are mapped to the OID::NAMES entry. The remaining rows are
      # mapped to an OID::Vector of their element type. Only rows whose element
      # type was already in the map before this pass are included.
      def initialize_type_map
        rows = execute('SELECT oid, typname, typelem, typdelim FROM pg_type', 'SCHEMA')
        leaves, nodes = rows.partition { |row| row['typelem'] == '0' }

        # populate the leaf nodes
        leaves.each do |row|
          next unless OID.registered_type?(row['typname'])
          OID::TYPE_MAP[row['oid'].to_i] = OID::NAMES[row['typname']]
        end

        # populate composite types; select first so entries added in this loop
        # do not make further nodes eligible
        mapped_nodes = nodes.select { |row| OID::TYPE_MAP.key?(row['typelem'].to_i) }
        mapped_nodes.each do |row|
          element_type = OID::TYPE_MAP[row['typelem'].to_i]
          OID::TYPE_MAP[row['oid'].to_i] = OID::Vector.new(row['typdelim'], element_type)
        end
      end

1371 1372
        # SQLSTATE that #exec_cache treats as "prepared statement must be re-prepared".
        FEATURE_NOT_SUPPORTED = '0A000' # :nodoc:

1373 1374
        # Sends +sql+ to the server unprepared via async_exec. +binds+ is
        # accepted for signature parity with #exec_cache and is not used.
        def exec_no_cache(sql, binds)
          @connection.async_exec sql
        end
1376

1377
        # Executes +sql+ as a prepared statement with +binds+ type-cast per
        # column, and returns the last result from the connection.
        #
        # If the server rejects the statement with FEATURE_NOT_SUPPORTED, the
        # cached statement is dropped and the call is retried once with a fresh
        # prepare. Any other error, a second 0A000, or an error without a
        # readable SQLSTATE is re-raised.
        def exec_cache(sql, binds)
          retried = false
          begin
            stmt_key = prepare_statement sql

            # Clear the queue
            @connection.get_last_result
            @connection.send_query_prepared(stmt_key, binds.map { |col, val|
              type_cast(val, col)
            })
            @connection.block
            @connection.get_last_result
          rescue PGError => e
            # Get the PG code for the failure.  Annoyingly, the code for
            # prepared statements whose return value may have changed is
            # FEATURE_NOT_SUPPORTED.  Check here for more details:
            # http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/backend/utils/cache/plancache.c#l573
            begin
              code = e.result.result_error_field(PGresult::PG_DIAG_SQLSTATE)
            rescue
              raise e
            end

            # Retry only once: 0A000 is also raised for genuinely unsupported
            # features (even at prepare time), and an unbounded retry would
            # loop forever.
            if FEATURE_NOT_SUPPORTED == code && !retried
              retried = true
              @statements.delete sql_key(sql)
              retry
            else
              raise e
            end
          end
        end

        # Key for the client-side prepared-statement cache. The schema search
        # path is part of the key, so the same SQL under a different search path
        # is cached separately.
        def sql_key(sql)
          format('%s-%s', schema_search_path, sql)
        end

        # Returns the prepared-statement name for +sql+. On a cache miss the
        # statement is first prepared on the connection under the pool's next
        # key and recorded in @statements.
        def prepare_statement(sql)
          cache_key = sql_key(sql)
          return @statements[cache_key] if @statements.key?(cache_key)

          statement_name = @statements.next_key
          @connection.prepare(statement_name, sql)
          @statements[cache_key] = statement_name
          @statements[cache_key]
        end
1424

P
Pratik Naik 已提交
1425
        # Internal PostgreSQL type identifiers (pg_type OIDs) for the money and
        # BYTEA data types.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:
        BYTEA_COLUMN_TYPE_OID = 17 #:nodoc:
1429 1430 1431 1432

        # Connects to a PostgreSQL server and sets up the adapter depending on the
        # connected server's characteristics, then runs #configure_connection.
        def connect
          @connection = PGconn.connect(@connection_parameters)

          # Money type has a fixed precision of 10 in PostgreSQL 8.2 and below, and as of
          # PostgreSQL 8.3 it has a fixed precision of 19. PostgreSQLColumn.extract_precision
          # should know about this but can't detect it there, so deal with it here.
          at_least_8_3 = postgresql_version >= 80300
          PostgreSQLColumn.money_precision = at_least_8_3 ? 19 : 10

          configure_connection
        end

1443
        # Configures the encoding, verbosity, schema search path, and time zone of the connection.
        # This is called by #connect and should not be called manually.
        def configure_connection
          client_encoding = @config[:encoding]
          @connection.set_client_encoding(client_encoding) if client_encoding

          min_messages = @config[:min_messages]
          self.client_min_messages = min_messages if min_messages
          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]

          # Use standard-conforming strings if available so we don't have to do the E'...' dance.
          set_standard_conforming_strings

          # If using Active Record's time zone support configure the connection to return
          # TIMESTAMP WITH ZONE types in UTC; otherwise fall back to @local_tz when set.
          time_zone = ActiveRecord::Base.default_timezone == :utc ? 'UTC' : @local_tz
          execute("SET time zone '#{time_zone}'", 'SCHEMA') if time_zone
        end

1464
        # Returns the current value (currval) of +sequence_name+ as an Integer.
        # Kernel#Integer raises if the value is not a valid integer.
        def last_insert_id(sequence_name) #:nodoc:
          raw_value = last_insert_id_value(sequence_name)
          Integer(raw_value)
        end

D
Doug Cole 已提交
1469 1470 1471 1472 1473 1474
        # Returns the raw currval of +sequence_name+ (first column of the first row).
        def last_insert_id_value(sequence_name)
          rows = last_insert_id_result(sequence_name).rows
          rows[0][0]
        end

        # Runs SELECT currval($1) for +sequence_name+ and returns the query result.
        def last_insert_id_result(sequence_name) #:nodoc:
          binds = [[nil, sequence_name]]
          exec_query("SELECT currval($1)", 'SQL', binds)
        end

1477
        # Executes a SELECT query and returns the results, performing any data type
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
        # This delegates directly to #exec_query.
        def select(sql, name = nil, binds = [])
          exec_query sql, name, binds
        end

        # Runs +sql+ through #execute and returns <tt>[field_names, rows]</tt>,
        # with rows converted by #result_as_array. The raw PG result is always
        # cleared, even when the conversion raises.
        def select_raw(sql, name = nil)
          res = execute(sql, name)
          begin
            results = result_as_array(res)
            fields = res.fields
          ensure
            # Release the result's memory on both the success and error paths.
            res.clear
          end
          return fields, results
        end

1491
        # Returns the list of a table's column names, data types, and default values.
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # Each row is [attname, formatted type, default source, attnotnull,
        # atttypid, atttypmod], in column order.
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - the ::regclass cast resolves the table name to the table's oid
        def column_definitions(table_name) #:nodoc:
          result = exec_query(<<-end_sql, 'SCHEMA')
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), d.adsrc, a.attnotnull, a.atttypid, a.atttypmod
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{quote_table_name(table_name)}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
          result.rows
        end
1519 1520

        # Splits the leading identifier off +name+ and returns
        # <tt>[identifier, rest]</tt>, or nil when no identifier is found.
        #
        # A leading double-quoted identifier is returned without its quotes;
        # otherwise the first run of non-"." characters is taken. A single "."
        # after the identifier is dropped, and +rest+ is nil when nothing remains.
        def extract_pg_identifier_from_name(name)
          pattern = name.start_with?('"') ? /\"([^\"]+)\"/ : /([^\.]+)/
          match_data = name.match(pattern)
          return unless match_data

          rest = name[match_data[0].length, name.length]
          rest = rest[1, rest.length] if rest.start_with?(".")
          [match_data[1], rest.empty? ? nil : rest]
        end
1529

1530 1531 1532 1533 1534
        # Returns the table reference between INTO and the column list / VALUES
        # of an "INSERT ... VALUES (" statement (case-insensitive), stripped of
        # surrounding whitespace, or nil if +sql+ does not have that shape.
        def extract_table_ref_from_insert_sql(sql)
          match = sql.match(/into\s+([^\(]*).*values\s*\(/i)
          match[1].strip if match
        end

1535 1536 1537
        # Builds a TableDefinition bound to this adapter.
        def table_definition
          TableDefinition.new self
        end
D
Initial  
David Heinemeier Hansson 已提交
1538 1539 1540
    end
  end
end