postgresql_adapter.rb 39.5 KB
Newer Older
D
Initial  
David Heinemeier Hansson 已提交
1
require 'active_record/connection_adapters/abstract_adapter'
2
require 'active_support/core_ext/object/blank'
3
require 'active_record/connection_adapters/statement_pool'
4 5 6

# Make sure we're using pg high enough for PGResult#values
gem 'pg', '~> 0.11'
7
require 'pg'
D
Initial  
David Heinemeier Hansson 已提交
8 9 10 11 12

module ActiveRecord
  class Base
    # Builds a PostgreSQLAdapter from an Active Record configuration hash.
    #
    # Keys may be strings or symbols. Recognized keys are :host, :port
    # (default 5432), :username, :password and :database. :database must be
    # present or an ArgumentError is raised. The symbolized config hash is
    # also handed to the adapter itself.
    def self.postgresql_connection(config) # :nodoc:
      config = config.symbolize_keys

      unless config.key?(:database)
        raise ArgumentError, "No database specified. Missing argument: database."
      end

      host     = config[:host]
      port     = config[:port] || 5432
      username = config[:username] ? config[:username].to_s : nil
      password = config[:password] ? config[:password].to_s : nil
      database = config[:database]

      # The postgres driver cannot create an unconnected PGconn object, so the
      # adapter gets nil plus the parameters it needs to connect by itself.
      params = [host, port, nil, nil, database, username, password]
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, params, config)
    end
  end
30

31 32 33 34 35 36 37
  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      #
      # +default+ is the default expression as reported by PostgreSQL, for
      # example <tt>'foo'::character varying</tt>. It is reduced to a plain
      # value by ::extract_value_from_default before being passed to Column.
      def initialize(name, default, sql_type = nil, null = true)
        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end

      # :stopdoc:
      class << self
        # Precision reported for +money+ columns (see #extract_precision).
        attr_accessor :money_precision

        # Maps PostgreSQL's 'infinity' / '-infinity' timestamp literals to
        # Float infinities. Non-String values are returned unchanged, and all
        # other strings are parsed by Column.
        def string_to_time(string)
          return string unless String === string

          case string
          when 'infinity'  then 1.0 / 0.0
          when '-infinity' then -1.0 / 0.0
          else
            super
          end
        end
      end
      # :startdoc:

      private
        # bigint and smallint have fixed byte sizes; everything else defers to Column.
        def extract_limit(sql_type)
          case sql_type
          when /^bigint/i   then 8
          when /^smallint/i then 2
          else super
          end
        end

        # Extracts the scale from PostgreSQL-specific data types.
        def extract_scale(sql_type)
          # Money type has a fixed scale of 2.
          sql_type =~ /^money/ ? 2 : super
        end

        # Extracts the precision from PostgreSQL-specific data types.
        # For money the class-level money_precision setting is used.
        def extract_precision(sql_type)
          if sql_type == 'money'
            self.class.money_precision
          else
            super
          end
        end

        # Maps PostgreSQL-specific data types to logical Rails types.
        def simplified_type(field_type)
          case field_type
            # Numeric and monetary types
            when /^(?:real|double precision)$/
              :float
            # Monetary types
            when 'money'
              :decimal
            # Character types
            when /^(?:character varying|bpchar)(?:\(\d+\))?$/
              :string
            # Binary data types
            when 'bytea'
              :binary
            # Date/time types
            when /^timestamp with(?:out)? time zone$/
              :datetime
            when 'interval'
              :string
            # Geometric types
            when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
              :string
            # Network address types
            when /^(?:cidr|inet|macaddr)$/
              :string
            # Bit strings
            when /^bit(?: varying)?(?:\(\d+\))?$/
              :string
            # XML type
            when 'xml'
              :xml
            # tsvector type
            when 'tsvector'
              :tsvector
            # Arrays
            when /^\D+\[\]$/
              :string
            # Object identifier types
            when 'oid'
              :integer
            # UUID type
            when 'uuid'
              :string
            # Small and big integer types
            when /^(?:small|big)int$/
              :integer
            # Pass through all types that are not specific to PostgreSQL.
            else
              super
          end
        end

        # Extracts the value from a PostgreSQL column default definition.
        #
        # Returns the literal as a String, true/false for boolean defaults,
        # or nil when there is no default or it is an expression (function
        # call, user type, ...) whose value cannot be known here.
        def self.extract_value_from_default(default)
          case default
            # This is a performance optimization for Ruby 1.9.2 in development.
            # If the value is nil, we return nil straight away without checking
            # the regular expressions. If we check each regular expression,
            # Regexp#=== will call NilClass#to_str, which will trigger
            # method_missing (defined by whiny nil in ActiveSupport) which
            # makes this method very very slow.
            when NilClass
              nil
            # Numeric types, possibly wrapped in parentheses (e.g. "(-1)").
            # The closing parenthesis stays outside the capture so that
            # "(-1)" yields "-1" rather than "-1)".
            when /\A\(?(-?\d+(\.\d*)?)\)?\z/
              $1
            # Character types. Inside a quoted literal PostgreSQL doubles
            # embedded single quotes, so 'it''s' has to become it's.
            when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub("''", "'")
            # Character types (8.1 formatting)
            when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
            # Binary data types
            when /\A'(.*)'::bytea\z/m
              $1
            # Date/time types
            when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
              $1
            when /\A'(.*)'::interval\z/
              $1
            # Boolean type
            when 'true'
              true
            when 'false'
              false
            # Geometric types
            when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
              $1
            # Network address types
            when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
              $1
            # Bit string types
            when /\AB'(.*)'::"?bit(?: varying)?"?\z/
              $1
            # XML type
            when /\A'(.*)'::xml\z/m
              $1
            # Arrays
            when /\A'(.*)'::"?\D+"?\[\]\z/
              $1
            # Object identifier types. The capture group is required for $1;
            # previously $1 was always nil here.
            when /\A(-?\d+)\z/
              $1
            else
              # Anything else is blank, some user type, or some function
              # and we can't know the value of that, so return nil.
              nil
          end
        end
    end

191 192
    # The PostgreSQL adapter uses the native C driver provided by the +pg+ gem
    # (version ~> 0.11 is required).
193 194 195
    #
    # Options:
    #
P
Pratik Naik 已提交
196 197 198 199 200
    # * <tt>:host</tt> - Defaults to "localhost".
    # * <tt>:port</tt> - Defaults to 5432.
    # * <tt>:username</tt> - Defaults to nothing.
    # * <tt>:password</tt> - Defaults to nothing.
    # * <tt>:database</tt> - The name of the database. No default, must be provided.
201
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection given
202
    #   as a string of comma-separated schema names. This is backward-compatible with the <tt>:schema_order</tt> option.
203
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO
204
    #   <encoding></tt> call on the connection.
205
    # * <tt>:min_messages</tt> - An optional client min messages that is used in a
206
    #   <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
207
    class PostgreSQLAdapter < AbstractAdapter
208 209 210 211 212
      class TableDefinition < ActiveRecord::ConnectionAdapters::TableDefinition
        # Adds column helpers for the PostgreSQL-only types, e.g.
        # <tt>t.xml :payload</tt> or <tt>t.tsvector :search</tt>. Only the first
        # name is used, and a trailing options hash is forwarded to #column.
        %w(xml tsvector).each do |pg_type|
          define_method(pg_type) do |*args|
            options = args.extract_options!
            column(args[0], pg_type, options)
          end
        end
      end

220
      # Name reported by #adapter_name.
      ADAPTER_NAME = 'PostgreSQL'

      # Abstract Rails column types mapped to PostgreSQL column types;
      # returned by #native_database_types.
      NATIVE_DATABASE_TYPES = {
        :primary_key => "serial primary key",
        :string => { :name => "character varying", :limit => 255 },
        :text => { :name => "text" },
        :integer => { :name => "integer" },
        :float => { :name => "float" },
        :decimal => { :name => "decimal" },
        :datetime => { :name => "timestamp" },
        :timestamp => { :name => "timestamp" },
        :time => { :name => "time" },
        :date => { :name => "date" },
        :binary => { :name => "bytea" },
        :boolean => { :name => "boolean" },
        :xml => { :name => "xml" },
        :tsvector => { :name => "tsvector" }
      }

239
      # Name identifying this adapter: the ADAPTER_NAME constant ('PostgreSQL').
      def adapter_name
        ADAPTER_NAME
      end

244 245
      # Prepared statements are cached per connection (see StatementPool), so
      # statement caching is always reported as supported.
      def supports_statement_cache?
        true
      end

250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286
      # Cache mapping SQL text to the name of its server-side prepared
      # statement ("a1", "a2", ...). When the cache is full, the oldest
      # statement is DEALLOCATEd before a new one is stored.
      #
      # @connection and @max are presumably set by the superclass
      # constructor (not visible here).
      class StatementPool < ConnectionAdapters::StatementPool
        def initialize(connection, max)
          super
          @counter = 0
          @cache   = {}
        end

        def each(&block); @cache.each(&block); end
        def key?(key);    @cache.key?(key); end
        def [](key);      @cache[key]; end
        def length;       @cache.length; end

        # Name the next prepared statement should use. The counter only
        # advances when a statement is stored with #[]=.
        def next_key
          "a#{@counter + 1}"
        end

        # Stores +key+ as the statement name for +sql+. Oldest entries are
        # evicted first (Hash keeps insertion order). The emptiness check
        # avoids calling #last on nil when the limit is zero or negative.
        def []=(sql, key)
          while @max <= @cache.size && !@cache.empty?
            dealloc(@cache.shift.last)
          end
          @counter += 1
          @cache[sql] = key
        end

        # Deallocates every cached statement and empties the cache.
        def clear
          @cache.each_value { |stmt_key| dealloc(stmt_key) }
          @cache.clear
        end

        private
        # Issues DEALLOCATE only while the connection is usable. The adapter's
        # reconnect! clears this cache *before* resetting the connection, and
        # sending a query over a dead connection would raise there and abort
        # the reconnect. Prepared statements die with their session anyway.
        def dealloc(key)
          @connection.query "DEALLOCATE #{key}" if connection_active?
        end

        # Same check as the adapter's #active?: a PGError while reading the
        # status counts as "not active".
        def connection_active?
          @connection.status == PGconn::CONNECTION_OK
        rescue PGError
          false
        end
      end

287 288
      # Initializes and connects a PostgreSQL adapter.
      #
      # +connection_parameters+ is the positional parameter list (as built by
      # ActiveRecord::Base.postgresql_connection). +config+ is the
      # configuration hash; its <tt>:statement_limit</tt> caps the prepared
      # statement cache (default 1000). Raises a RuntimeError when the server
      # is older than PostgreSQL 8.2.
      def initialize(connection, logger, connection_parameters, config)
        super(connection, logger)
        @connection_parameters = connection_parameters
        @config                = config

        # Pre-set to nil so #connect can read them without uninitialized-ivar
        # warnings; @local_tz gets its real value at the end.
        @local_tz           = nil
        @table_alias_length = nil

        connect
        @statements = StatementPool.new(@connection, config.fetch(:statement_limit, 1000))

        raise "Your version of PostgreSQL (#{postgresql_version}) is too old, please upgrade!" if postgresql_version < 80200

        time_zone_row = execute('SHOW TIME ZONE', 'SCHEMA').first
        @local_tz = time_zone_row["TimeZone"]
      end

307 308 309 310
      # Arel visitor used to compile query trees into PostgreSQL SQL for +pool+.
      def self.visitor_for(pool) # :nodoc:
        visitor_class = Arel::Visitors::PostgreSQL
        visitor_class.new(pool)
      end

X
Xavier Noria 已提交
311
      # Empties the prepared statement cache. StatementPool#clear also
      # deallocates each cached statement.
      def clear_cache!
        @statements.clear
      end

316 317
      # Is this connection alive and ready for queries? A PGError raised while
      # reading the status counts as "not active".
      def active?
        current_status = @connection.status
        current_status == PGconn::CONNECTION_OK
      rescue PGError
        false
      end

      # Close then reopen the connection. The order is: drop the prepared
      # statement cache, reset the PG connection, then re-apply session
      # settings via #configure_connection.
      def reconnect!
        clear_cache!
        @connection.reset
        configure_connection
      end
329

330 331 332 333 334
      # Clears the prepared statement cache, then defers to the generic reset.
      def reset!
        clear_cache!
        super
      end

335 336
      # Disconnects from the database if already connected. Otherwise, this
      # method does nothing: any StandardError raised by close (e.g. because
      # the connection is already closed) is ignored and nil is returned.
      def disconnect!
        clear_cache!
        begin
          @connection.close
        rescue
          nil
        end
      end
341

342
      # Mapping of abstract Rails column types to PostgreSQL types.
      def native_database_types #:nodoc:
        NATIVE_DATABASE_TYPES
      end
345

346
      # Migrations are supported by this adapter.
      def supports_migrations?
        true
      end

351
      # Primary keys can be looked up on tables not created by Active Record
      # (see #primary_key).
      def supports_primary_key? #:nodoc:
        true
      end

356 357 358
      # Enable standard-conforming strings if available.
      #
      # client_min_messages is raised to 'panic' while trying so that servers
      # rejecting the setting do not log a warning. Errors from the SET itself
      # are swallowed, and the previous message level is always restored.
      def set_standard_conforming_strings
        old = client_min_messages
        self.client_min_messages = 'panic'
        begin
          execute('SET standard_conforming_strings = on', 'SCHEMA')
        rescue
          nil
        end
      ensure
        self.client_min_messages = old
      end

364
      # INSERT ... RETURNING is available, so new IDs come back in one query.
      def supports_insert_with_returning?
        true
      end

368 369 370
      # DDL statements can run inside transactions in PostgreSQL.
      def supports_ddl_transactions?
        true
      end
371

372
      # Savepoints are supported (see #create_savepoint and friends).
      def supports_savepoints?
        true
      end
376

377
      # Maximum identifier length configured on the server
      # (max_identifier_length). It is queried once and then memoized.
      def table_alias_length
        return @table_alias_length if @table_alias_length
        @table_alias_length = query('SHOW max_identifier_length').first.first.to_i
      end
381

382 383
      # QUOTING ==================================================

384
      # Escapes binary strings for bytea input to the database. Returns nil
      # for a nil (or false) +value+.
      def escape_bytea(value)
        return unless value
        @connection.escape_bytea(value)
      end

      # Unescapes bytea output from a database to the binary string it represents.
      # NOTE: This is NOT an inverse of escape_bytea! This is only to be used
      #       on escaped binary output from the database driver.
      # Returns nil for a nil (or false) +value+.
      def unescape_bytea(value)
        return unless value
        @connection.unescape_bytea(value)
      end

396 397
      # Quotes PostgreSQL-specific data types for SQL input.
      #
      # Without a +column+, or for values/column types not handled here, the
      # generic quoting of the superclass is used. Handled cases:
      # * infinite Floats for datetime columns become 'infinity' / '-infinity';
      # * other Numerics for money columns are wrapped in plain quotes;
      # * Strings for bytea, xml and bit columns use the matching literal syntax.
      def quote(value, column = nil) #:nodoc:
        return super unless column

        case value
        when Float
          return super unless value.infinite? && column.type == :datetime
          "'#{value.to_s.downcase}'"
        when Numeric
          return super unless column.sql_type == 'money'
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value}'"
        when String
          case column.sql_type
          when 'bytea' then "'#{escape_bytea(value)}'"
          when 'xml'   then "xml '#{quote_string(value)}'"
          when /^bit/
            # The value is interpolated without escaping, so it must consist
            # entirely of bit/hex digits. Anchor with \A..\z: with ^..$ a
            # value such as "01\n'; DROP ..." matched its first line and was
            # injected verbatim. Anything else gets generic (escaped) quoting
            # instead of the former nil result.
            case value
            when /\A[01]*\z/      then "B'#{value}'" # Bit-string notation
            when /\A[0-9A-F]*\z/i then "X'#{value}'" # Hexadecimal notation
            else super
            end
          else
            super
          end
        else
          super
        end
      end

425 426 427 428 429 430
      # Casts +value+ for use as a bind parameter. Strings bound to bytea
      # columns are sent as a binary-format parameter (:format => 1); all
      # other values use the generic cast.
      def type_cast(value, column)
        return super unless column
        return super unless String === value && 'bytea' == column.sql_type

        { :value => value, :format => 1 }
      end

437 438 439
      # Escapes +s+ for inclusion in a single-quoted SQL literal, using the
      # connection's own escaping.
      def quote_string(s) #:nodoc:
        connection = @connection
        connection.escape(s)
      end

442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460
      # Quotes a possibly schema-qualified table name. Handles:
      #
      # - table_name
      # - "table.name"
      # - schema_name.table_name
      # - schema_name."table.name"
      # - "schema.name".table_name
      # - "schema.name"."table.name"
      def quote_table_name(name)
        schema, remainder = extract_pg_identifier_from_name(name.to_s)
        return quote_column_name(schema) unless remainder

        table, _ = extract_pg_identifier_from_name(remainder)
        [schema, table].map { |part| quote_column_name(part) }.join('.')
      end

461 462
      # Quotes an identifier (column, table, schema, ...) with the driver's
      # PGconn.quote_ident; +name+ may be any object responding to #to_s.
      def quote_column_name(name) #:nodoc:
        identifier = name.to_s
        PGconn.quote_ident(identifier)
      end

466 467 468
      # Quotes date/time values for SQL input. For time-like values that
      # respond to usec, microseconds are appended as a six-digit fraction.
      def quoted_date(value) #:nodoc:
        return super unless value.acts_like?(:time) && value.respond_to?(:usec)
        "#{super}.#{sprintf("%06d", value.usec)}"
      end

476 477
      # Sets the authorized user for this session (SET SESSION AUTHORIZATION)
      # after clearing the prepared statement cache. +user+ is interpolated
      # verbatim, so it must come from a trusted source.
      def session_auth=(user)
        clear_cache!
        statement = "SET SESSION AUTHORIZATION #{user}"
        exec_query statement
      end

482 483
      # REFERENTIAL INTEGRITY ====================================

484
      # Triggers (and therefore FK checks) can be disabled per table, which
      # #disable_referential_integrity relies on.
      def supports_disable_referential_integrity? #:nodoc:
        true
      end

488
      # Disables all triggers on every table in the search path, yields, then
      # re-enables them (also when the block raises). Returns the block's value.
      def disable_referential_integrity #:nodoc:
        toggle_triggers = lambda do |action|
          if supports_disable_referential_integrity?
            statements = tables.map { |name| "ALTER TABLE #{quote_table_name(name)} #{action} TRIGGER ALL" }
            execute(statements.join(";"))
          end
        end

        toggle_triggers.call('DISABLE')
        yield
      ensure
        toggle_triggers.call('ENABLE')
      end
498 499 500

      # DATABASE STATEMENTS ======================================

501 502 503 504 505 506
      # Executes a SELECT query and returns an array of rows, each an array of
      # field values (the last element of #select_raw's result).
      def select_rows(sql, name = nil)
        raw = select_raw(sql, name)
        raw.last
      end

507
      # Executes an INSERT query and returns the new record's ID.
      #
      # Without an explicit +pk+, the table is sniffed from the INSERT text
      # and its primary key looked up. When a key is known, the statement is
      # sent with RETURNING so the ID comes back in the same round-trip;
      # otherwise the generic implementation is used.
      def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        unless pk
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          pk = primary_key(table_ref) if table_ref
        end

        return super unless pk

        select_value("#{sql} RETURNING #{quote_column_name(pk)}")
      end
521
      alias_method :create, :insert
522

523 524
      # Converts a PGresult into an Array of rows, each an Array of Strings.
      #
      # Two column types are post-processed:
      # * bytea values are unescaped into raw binary strings;
      # * money values have their locale-specific currency symbol and digit
      #   grouping stripped so they can be read as plain decimals.
      # Without such columns, res.values is returned untouched.
      def result_as_array(res) #:nodoc:
        # Pair each column index with its type OID.
        ftypes = Array.new(res.nfields) { |i| [i, res.ftype(i)] }

        rows = res.values
        return rows unless ftypes.any? { |_, type|
          type == BYTEA_COLUMN_TYPE_OID || type == MONEY_COLUMN_TYPE_OID
        }

        typehash = ftypes.group_by { |_, type| type }
        binaries = typehash[BYTEA_COLUMN_TYPE_OID] || []
        monies   = typehash[MONEY_COLUMN_TYPE_OID] || []

        rows.each do |row|
          # unescape string passed BYTEA field (OID == 17)
          binaries.each do |index, _|
            row[index] = unescape_bytea(row[index])
          end

          # If this is a money type column and there are any currency symbols,
          # then strip them off. Indeed it would be prettier to do this in
          # PostgreSQLColumn.string_to_decimal but would break form input
          # fields that call value_before_type_cast.
          monies.each do |index, _|
            data = row[index]
            # Because money output is formatted according to the locale, there are two
            # cases to consider (note the decimal separators):
            #  (1) $12,345,678.12
            #  (2) $12.345.678,12
            # \A and \z anchor the whole value rather than a single line.
            case data
            when /\A-?\D+[\d,]+\.\d{2}\z/  # (1)
              data.gsub!(/[^-\d.]/, '')
            when /\A-?\D+[\d.]+,\d{2}\z/  # (2)
              # gsub! returns nil when it removes nothing, so the two steps
              # must not be chained (that raised NoMethodError on nil).
              data.gsub!(/[^-\d,]/, '')
              data.sub!(/,/, '.')
            end
          end
        end
      end


      # Queries the database and returns the rows as an Array of Arrays of
      # Strings (see #result_as_array). The PGresult is cleared once its
      # values have been copied out, instead of waiting for GC to free it;
      # the ensure also frees it if the conversion raises.
      def query(sql, name = nil) #:nodoc:
        log(sql, name) do
          result = @connection.async_exec(sql)
          begin
            result_as_array(result)
          ensure
            result.clear
          end
        end
      end

573
      # Executes an SQL statement (logged under +name+), returning a PGresult
      # object on success or raising a PGError exception otherwise. The caller
      # owns the result.
      def execute(sql, name = nil)
        log(sql, name) { @connection.async_exec(sql) }
      end

581 582
      # Bind placeholder for the zero-based +index+-th bind value. PostgreSQL
      # numbers its parameters from $1. +column+ is not used.
      def substitute_at(column, index)
        position = index + 1
        Arel.sql("$#{position}")
      end

A
Aaron Patterson 已提交
585
      # Runs +sql+ and returns an ActiveRecord::Result of its fields and rows.
      #
      # Statements with +binds+ go through the prepared statement cache;
      # others are executed directly. The result is produced as the block's
      # value instead of via an explicit +return+: a non-local return unwinds
      # out of #log and skips whatever log does after yielding. The ensure
      # frees the PGresult even if building the rows raises.
      def exec_query(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = binds.empty? ? exec_no_cache(sql, binds) :
                                  exec_cache(sql, binds)
          begin
            ActiveRecord::Result.new(result.fields, result_as_array(result))
          ensure
            result.clear
          end
        end
      end

596 597 598 599 600 601 602 603 604
      # Runs a DELETE (or, via the exec_update alias, an UPDATE) and returns
      # the number of affected rows. Prepared statements are used only when
      # there are +binds+. The PGresult is cleared before returning.
      def exec_delete(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = if binds.empty?
                     exec_no_cache(sql, binds)
                   else
                     exec_cache(sql, binds)
                   end
          result.cmd_tuples.tap { result.clear }
        end
      end
      alias :exec_update :exec_delete
606

607 608
      # Returns <tt>[sql, binds]</tt> for an INSERT, with
      # <tt>RETURNING <pk></tt> appended when a primary key is given or can be
      # found from the table named in +sql+.
      def sql_for_insert(sql, pk, id_value, sequence_name, binds)
        unless pk
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          pk = primary_key(table_ref) if table_ref
        end

        final_sql = pk ? "#{sql} RETURNING #{quote_column_name(pk)}" : sql
        [final_sql, binds]
      end

619
      # Executes an UPDATE query and returns the number of affected tuples
      # (cmd_tuples of the result produced by the generic implementation).
      def update_sql(sql, name = nil)
        result = super
        result.cmd_tuples
      end

624 625
      # Opens a transaction block on the server.
      def begin_db_transaction
        execute('BEGIN')
      end

629 630
      # Commits the current transaction block.
      def commit_db_transaction
        execute('COMMIT')
      end
633

634 635
      # Aborts the current transaction block.
      def rollback_db_transaction
        execute('ROLLBACK')
      end
638

639 640
      # True when the server reports the connection idle, i.e. not inside a
      # transaction block.
      def outside_transaction?
        status = @connection.transaction_status
        status == PGconn::PQTRANS_IDLE
      end
642

J
Jonathan Viney 已提交
643 644 645 646 647 648 649 650
      # Creates a savepoint named after current_savepoint_name.
      def create_savepoint
        name = current_savepoint_name
        execute("SAVEPOINT #{name}")
      end

      # Rolls back to the savepoint named after current_savepoint_name.
      def rollback_to_savepoint
        name = current_savepoint_name
        execute("ROLLBACK TO SAVEPOINT #{name}")
      end

651
      # Releases the savepoint named after current_savepoint_name.
      def release_savepoint
        name = current_savepoint_name
        execute("RELEASE SAVEPOINT #{name}")
      end
654

655 656
      # SCHEMA STATEMENTS ========================================

657 658 659
      # Drops the database +name+ (if it exists) and creates it again with
      # the given +options+ (see #create_database).
      def recreate_database(name, options = {}) #:nodoc:
        drop_database(name)
        create_database(name, options)
      end

664
      # Create a new PostgreSQL database. Options include <tt>:owner</tt>, <tt>:template</tt>,
      # <tt>:encoding</tt> (default "utf8"), <tt>:tablespace</tt>, and <tt>:connection_limit</tt>
      # (note that MySQL uses <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
      # Other keys (e.g. the rest of a database.yml config) are ignored. Option
      # keys may be strings or symbols.
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        # Symbolize before applying the default so a string "encoding" key
        # is not overridden by the default value.
        options = options.symbolize_keys.reverse_merge(:encoding => "utf8")

        # Identifiers go through quote_column_name and the encoding through
        # quote_string, so values containing quotes cannot break the
        # statement. Plain names produce the same SQL as before.
        option_string = options.map do |key, value|
          case key
          when :owner
            " OWNER = #{quote_column_name(value)}"
          when :template
            " TEMPLATE = #{quote_column_name(value)}"
          when :encoding
            " ENCODING = '#{quote_string(value.to_s)}'"
          when :tablespace
            " TABLESPACE = #{quote_column_name(value)}"
          when :connection_limit
            " CONNECTION LIMIT = #{value}"
          else
            ""
          end
        end.join

        execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
      end

694
      # Drops a PostgreSQL database; missing databases are not an error
      # (DROP DATABASE IF EXISTS).
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        statement = "DROP DATABASE IF EXISTS #{quote_table_name(name)}"
        execute statement
      end

702 703
      # Returns the names of all tables in the schemas of the current search
      # path (current_schemas(false)). +name+ is unused.
      def tables(name = nil)
        rows = query(<<-SQL, 'SCHEMA')
          SELECT tablename
          FROM pg_tables
          WHERE schemaname = ANY (current_schemas(false))
        SQL
        rows.map(&:first)
      end

711
      # Returns true if a table or view named +name+ exists.
      # If the schema is not specified as part of +name+ then it will only find tables within
      # the current schema search path (regardless of permissions to access tables in other schemas)
      def table_exists?(name)
        schema, table = Utils.extract_schema_and_table(name.to_s)
        return false unless table

        binds = [[nil, table]]
        binds << [nil, schema] if schema
        namespace_filter = schema ? '$2' : 'ANY (current_schemas(false))'

        count = exec_query(<<-SQL, 'SCHEMA', binds).rows.first.first
          SELECT COUNT(*)
          FROM pg_class c
          LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
          WHERE c.relkind in ('v','r')
          AND c.relname = $1
          AND n.nspname = #{namespace_filter}
        SQL
        count.to_i > 0
      end

731 732 733 734 735 736 737 738
      # Returns true if a schema (pg_namespace entry) named +name+ exists.
      def schema_exists?(name)
        matches = exec_query(<<-SQL, 'SCHEMA', [[nil, name]]).rows.first.first
          SELECT COUNT(*)
          FROM pg_namespace
          WHERE nspname = $1
        SQL
        matches.to_i > 0
      end
739

740
      # Returns an array of indexes for the given table.
      #
      # Only non-primary-key indexes on tables in the current schema search
      # path are returned. Indexes none of whose key columns resolve to a
      # table column (indkey entries of 0) are skipped.
      def indexes(table_name, name = nil)
        # The schema filter uses current_schemas(false), as #tables and
        # #table_exists? do. The old approach split the raw search_path string
        # on "," and missed schemas written with a leading space ("a, b") as
        # well as "$user". The table name is escaped rather than interpolated raw.
        result = query(<<-SQL, name)
          SELECT distinct i.relname, d.indisunique, d.indkey, t.oid
          FROM pg_class t
          INNER JOIN pg_index d ON t.oid = d.indrelid
          INNER JOIN pg_class i ON d.indexrelid = i.oid
          WHERE i.relkind = 'i'
            AND d.indisprimary = 'f'
            AND t.relname = '#{quote_string(table_name.to_s)}'
            AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname = ANY (current_schemas(false)))
          ORDER BY i.relname
        SQL

        result.map do |row|
          index_name = row[0]
          unique     = row[1] == 't'
          indkey     = row[2].split(" ")
          oid        = row[3]

          # attnum => attname for the index's key columns.
          columns = Hash[query(<<-SQL, "Columns for index #{index_name} on #{table_name}")]
          SELECT a.attnum, a.attname
          FROM pg_attribute a
          WHERE a.attrelid = #{oid}
          AND a.attnum IN (#{indkey.join(",")})
          SQL

          # Keep the index's own column order (indkey order).
          column_names = columns.values_at(*indkey).compact
          column_names.empty? ? nil : IndexDefinition.new(table_name, index_name, unique, column_names)
        end.compact
      end

774 775
      # Returns the list of all column definitions for a table, as
      # PostgreSQLColumn objects. A column is nullable when its notnull flag is 'f'.
      def columns(table_name, name = nil)
        # Limit, precision, and scale are all handled by the superclass.
        column_definitions(table_name).map do |column_name, type, default, notnull|
          nullable = notnull == 'f'
          PostgreSQLColumn.new(column_name, default, type, nullable)
        end
      end

782 783 784 785 786
      # Returns the name of the database this connection is using.
      def current_database
        row = query('select current_database()').first
        row.first
      end

787 788 789 790 791
      # Returns the first schema of the search path (SELECT current_schema).
      def current_schema
        row = query('SELECT current_schema', 'SCHEMA').first
        row.first
      end

792 793 794 795 796 797 798 799
      # Returns the current database's encoding name (e.g. "UTF8").
      def encoding
        # Match with "= current_database()" instead of LIKE against an
        # interpolated name: "_" and "%" in a database name are LIKE
        # wildcards and could pick another database's encoding, a quote in
        # the name broke the SQL, and this saves a round-trip.
        query(<<-end_sql)[0][0]
          SELECT pg_encoding_to_char(pg_database.encoding) FROM pg_database
          WHERE pg_database.datname = current_database()
        end_sql
      end

800 801 802 803 804 805
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      # A nil argument is ignored. The string is passed to SET unescaped.
      #
      # This should not be called manually but set in database.yml.
      def schema_search_path=(schema_csv)
        return unless schema_csv
        execute "SET search_path TO #{schema_csv}"
        @schema_search_path = schema_csv
      end

812 813
      # Returns the active schema search path. The value last assigned through
      # #schema_search_path= is used when present; otherwise it is read once
      # from the server (SHOW search_path) and cached.
      def schema_search_path
        unless @schema_search_path
          @schema_search_path = query('SHOW search_path').first.first
        end
        @schema_search_path
      end
816

817 818
      # Returns the session's client_min_messages level (e.g. "notice").
      def client_min_messages
        row = query('SHOW client_min_messages', 'SCHEMA').first
        row.first
      end

      # Sets the session's client_min_messages level. +level+ is interpolated
      # into a quoted literal without escaping.
      def client_min_messages=(level)
        statement = "SET client_min_messages TO '#{level}'"
        execute(statement, 'SCHEMA')
      end

      # Returns the sequence name for a table's primary key or some other
      # specified key (default 'id'), without schema qualification. If the
      # lookup raises ActiveRecord::StatementInvalid, the conventional
      # "<table>_<column>_seq" name is returned instead.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        column = pk || 'id'
        serial_sequence(table_name, column).split('.').last
      rescue ActiveRecord::StatementInvalid
        "#{table_name}_#{column}_seq"
      end

      # Returns what pg_get_serial_sequence reports for +table+.+column+: the
      # name of the sequence owned by that column, or nil (SQL NULL) if none.
      def serial_sequence(table, column)
        binds = [[nil, table], [nil, column]]
        exec_query(<<-eosql, 'SCHEMA', binds).rows.first.first
          SELECT pg_get_serial_sequence($1, $2)
        eosql
      end

841 842
      # Resets the sequence of a table's primary key to the maximum value.
      #
      # A missing +pk+ or +sequence+ is filled in from #pk_and_sequence_for.
      # If a primary key has no sequence, a warning is logged (when a logger
      # is set). Nothing is executed, and nil is returned, unless both are known.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        unless pk && sequence
          default_pk, default_sequence = pk_and_sequence_for(table)

          pk       ||= default_pk
          sequence ||= default_sequence
        end

        if @logger && pk && !sequence
          @logger.warn "#{table} has primary key #{pk} with no default sequence"
        end

        return unless pk && sequence

        quoted_sequence = quote_table_name(sequence)

        # setval(..., false) makes the next nextval return the given value:
        # MAX(pk) + increment_by, or min_value for an empty table.
        select_value <<-end_sql, 'Reset sequence'
            SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
          end_sql
      end

863 864
      # Returns a table's primary key and belonging sequence as
      # <tt>[primary_key, sequence]</tt>, or nil when they cannot be determined.
      #
      # Looks for a sequence with a dependency on the given table's primary key.
      # The sequence name is returned bare when it lives in the "public" schema
      # and as "schema.sequence" otherwise.
      #
      # Any StandardError raised while querying is swallowed and nil returned,
      # for example when the ::regclass cast cannot resolve +table+.
      # reset_pk_sequence! destructures the result and copes with nil.
      def pk_and_sequence_for(table) #:nodoc:
        result = exec_query(<<-end_sql, 'SCHEMA').rows.first
          SELECT attr.attname, ns.nspname, seq.relname
          FROM pg_class seq
          INNER JOIN pg_depend dep ON seq.oid = dep.objid
          INNER JOIN pg_attribute attr ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          INNER JOIN pg_namespace ns ON seq.relnamespace = ns.oid
          WHERE seq.relkind  = 'S'
            AND cons.contype = 'p'
            AND dep.refobjid = '#{quote_table_name(table)}'::regclass
        end_sql

        # No dependent sequence found: say so explicitly rather than relying on
        # a NoMethodError from indexing nil being swallowed by the rescue below.
        return nil unless result

        primary_key, schema, sequence = result
        sequence = "#{schema}.#{sequence}" unless schema == 'public'

        [primary_key, sequence]
      rescue
        nil
      end

891 892
      # Returns just a table's primary key column name, or nil when no primary
      # key column is found. +table+ is sent as a bind parameter and resolved
      # server-side with ::regclass.
      def primary_key(table)
        row = exec_query(<<-end_sql, 'SCHEMA', [[nil, table]]).rows.first
          SELECT DISTINCT(attr.attname)
          FROM pg_attribute attr
          INNER JOIN pg_depend dep ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          WHERE cons.contype = 'p'
            AND dep.refobjid = $1::regclass
        end_sql

        row && row.first
      end

905
      # Renames a table with a single ALTER TABLE ... RENAME TO statement;
      # both names are quoted with quote_table_name.
      #
      # Example:
      #   rename_table('octopuses', 'octopi')
      def rename_table(name, new_name)
        execute "ALTER TABLE #{quote_table_name(name)} RENAME TO #{quote_table_name(new_name)}"
      end
912

913 914
      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
      #
      # The column type is built by type_to_sql from the :limit, :precision and
      # :scale options. add_column_options! then receives the statement string
      # and +options+ to append the remaining column options.
      # NOTE(review): it is presumed to mutate the string in place, since its
      # return value is ignored.
      def add_column(table_name, column_name, type, options = {})
        column_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])
        add_column_sql = "ALTER TABLE #{quote_table_name(table_name)} ADD COLUMN #{quote_column_name(column_name)} #{column_type}"
        add_column_options!(add_column_sql, options)

        execute add_column_sql
      end
D
Initial  
David Heinemeier Hansson 已提交
921

922 923
      # Changes the column of a table.
      #
      # Issues ALTER COLUMN ... TYPE for the new type. It then applies:
      # * :default, when options_include_default? reports one, via
      #   change_column_default;
      # * :null, when that key is present, via change_column_null, which also
      #   receives options[:default].
      def change_column(table_name, column_name, type, options = {})
        quoted_table_name = quote_table_name(table_name)
        sql_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])

        execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quote_column_name(column_name)} TYPE #{sql_type}"

        change_column_default(table_name, column_name, options[:default]) if options_include_default?(options)
        change_column_null(table_name, column_name, options[:null], options[:default]) if options.key?(:null)
      end
931

932 933
      # Changes the default value of a table column. +default+ is rendered as a
      # SQL literal with quote.
      def change_column_default(table_name, column_name, default)
        execute "ALTER TABLE #{quote_table_name(table_name)} ALTER COLUMN #{quote_column_name(column_name)} SET DEFAULT #{quote(default)}"
      end
936

937 938
      # Drops (+null+ truthy) or sets (+null+ falsy) the NOT NULL constraint on
      # a column.
      #
      # When making the column NOT NULL and a non-nil +default+ is given,
      # existing NULL values are first backfilled with that default, so the
      # constraint can be applied to the existing rows.
      def change_column_null(table_name, column_name, null, default = nil)
        if !null && !default.nil?
          execute("UPDATE #{quote_table_name(table_name)} SET #{quote_column_name(column_name)}=#{quote(default)} WHERE #{quote_column_name(column_name)} IS NULL")
        end
        execute("ALTER TABLE #{quote_table_name(table_name)} ALTER #{quote_column_name(column_name)} #{null ? 'DROP' : 'SET'} NOT NULL")
      end

944 945
      # Renames a column in a table with ALTER TABLE ... RENAME COLUMN.
      def rename_column(table_name, column_name, new_column_name)
        execute "ALTER TABLE #{quote_table_name(table_name)} RENAME COLUMN #{quote_column_name(column_name)} TO #{quote_column_name(new_column_name)}"
      end
948

949 950 951 952
      # Drops the index called +index_name+, quoted as a (possibly
      # schema-qualified) relation name. +table_name+ is accepted for interface
      # compatibility but is not used in the statement.
      def remove_index!(table_name, index_name) #:nodoc:
        quoted_index = quote_table_name(index_name)
        execute "DROP INDEX #{quoted_index}"
      end

953 954 955 956
      # Renames an index. +table_name+ is not used.
      #
      # +old_name+ may be schema-qualified, so it is quoted as a relation name,
      # the same way remove_index! quotes the index it drops. PostgreSQL's
      # RENAME TO clause takes a bare identifier, so +new_name+ is quoted as a
      # single identifier. Previously the two quoting helpers were swapped.
      def rename_index(table_name, old_name, new_name)
        execute "ALTER INDEX #{quote_table_name(old_name)} RENAME TO #{quote_column_name(new_name)}"
      end

957 958
      # Maximum length of an index name: 63, PostgreSQL's identifier length
      # limit (NAMEDATALEN - 1 in a default build).
      def index_name_length
        63
      end
960

961 962
      # Maps logical Rails types to PostgreSQL-specific data types.
      #
      # Only the integer type is handled here; every other type is delegated to
      # the superclass. For integers +limit+ is a byte size:
      #   nil  -> integer
      #   1, 2 -> smallint
      #   3, 4 -> integer
      #   5..8 -> bigint
      # Any other limit raises ActiveRecordError.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        return super unless type.to_s == 'integer'
        return 'integer' unless limit

        case limit
        when 1, 2 then 'smallint'
        when 3, 4 then 'integer'
        when 5..8 then 'bigint'
        else raise(ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead.")
        end
      end
973

974
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column.
      #
      #   distinct("posts.id", ["posts.created_at desc"])
      #   # => "DISTINCT posts.id, posts.created_at AS alias_0"
      #
      # +orders+ is an array of ORDER BY fragments. Each fragment is cleaned up:
      # * ASC/DESC and NULLS FIRST/LAST modifiers are removed;
      # * blank fragments are dropped;
      # * each remaining expression is aliased as alias_0, alias_1, ...
      def distinct(columns, orders) #:nodoc:
        return "DISTINCT #{columns}" if orders.empty?

        # Remove each modifier together with its leading whitespace only. The old
        # pattern also ate the trailing space, so "x DESC NULLS LAST" became
        # "xNULLS LAST". The \b boundary leaves words like "DESCRIPTION" alone.
        order_columns = orders.map do |order|
          order.gsub(/\s+(ASC|DESC)\b/i, '').gsub(/\s+NULLS\s+(FIRST|LAST)\b/i, '').strip
        end
        order_columns = order_columns.grep(/\S/)

        # Only blank fragments: "DISTINCT cols, " would be invalid SQL.
        return "DISTINCT #{columns}" if order_columns.empty?

        aliased = order_columns.each_with_index.map { |column, i| "#{column} AS alias_#{i}" }
        "DISTINCT #{columns}, #{aliased.join(', ')}"
      end
991

992
      module Utils
        extend self

        # Returns an array of <tt>[schema_name, table_name]</tt> extracted from +name+.
        # +schema_name+ is nil if not specified in +name+.
        # +schema_name+ and +table_name+ exclude surrounding quotes (regardless of whether provided in +name+)
        # +name+ supports the range of schema/table references understood by PostgreSQL, for example:
        #
        # * <tt>table_name</tt>
        # * <tt>"table.name"</tt>
        # * <tt>schema_name.table_name</tt>
        # * <tt>schema_name."table.name"</tt>
        # * <tt>"schema.name"."table name"</tt>
        #
        # Only the first two name parts are considered.
        def extract_schema_and_table(name)
          parts = name.scan(/[^".\s]+|"[^"]*"/)[0..1].collect { |part| part.gsub(/(\A"|"\z)/, '') }
          table, schema = parts.reverse
          [schema, table]
        end
      end

1011
      protected
1012
        # Returns the version of the connected PostgreSQL server, as reported by
        # the connection's server_version (compared against integers such as
        # 80300 in #connect).
        def postgresql_version
          @connection.server_version
        end

1017 1018 1019
        # Maps a database error onto an Active Record exception by matching its
        # message:
        # * a unique-constraint violation becomes RecordNotUnique;
        # * a foreign-key violation becomes InvalidForeignKey.
        # Both are built from +message+ and the original +exception+. Anything
        # else is handed to the superclass implementation.
        def translate_exception(exception, message)
          case exception.message
          when /duplicate key value violates unique constraint/
            RecordNotUnique.new(message, exception)
          when /violates foreign key constraint/
            InvalidForeignKey.new(message, exception)
          else
            super
          end
        end

D
Initial  
David Heinemeier Hansson 已提交
1028
      private
1029 1030
        # Sends +sql+ with async_exec, without preparing it. +binds+ is accepted
        # for signature parity with exec_cache but is not used.
        def exec_no_cache(sql, binds)
          @connection.async_exec(sql)
        end
1032

1033 1034
        # Runs +sql+ as a prepared statement and returns the connection's last
        # result.
        #
        # @statements maps SQL text to prepared-statement names: a statement is
        # prepared (under @statements.next_key) only the first time its SQL is
        # seen. +binds+ is a list of [column, value] pairs; each value is
        # converted with type_cast(value, column) before sending.
        def exec_cache(sql, binds)
          unless @statements.key? sql
            nextkey = @statements.next_key
            @connection.prepare nextkey, sql
            @statements[sql] = nextkey
          end

          key = @statements[sql]

          # Clear the queue
          @connection.get_last_result
          @connection.send_query_prepared(key, binds.map { |col, val|
            type_cast(val, col)
          })
          @connection.block
          @connection.get_last_result
        end
1050

P
Pratik Naik 已提交
1051
        # The internal PostgreSQL identifier (type OID) of the money data type.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:
        # The internal PostgreSQL identifier (type OID) of the BYTEA data type.
        BYTEA_COLUMN_TYPE_OID = 17 #:nodoc:
1055 1056 1057 1058 1059 1060 1061 1062 1063

        # Connects to a PostgreSQL server and sets up the adapter depending on the
        # connected server's characteristics.
        #
        # Opens the connection with the stored @connection_parameters, sets the
        # money precision from the server version and then runs
        # configure_connection.
        def connect
          @connection = PGconn.connect(*@connection_parameters)

          # Money type has a fixed precision of 10 in PostgreSQL 8.2 and below, and as of
          # PostgreSQL 8.3 it has a fixed precision of 19. PostgreSQLColumn.extract_precision
          # should know about this but can't detect it there, so deal with it here.
          PostgreSQLColumn.money_precision = (postgresql_version >= 80300) ? 19 : 10

          configure_connection
        end

1069
        # Configures the encoding, verbosity, schema search path, and time zone of the connection.
        # This is called by #connect and should not be called manually.
        #
        # Reads these @config keys:
        # * :encoding;
        # * :min_messages;
        # * :schema_search_path, falling back to :schema_order.
        def configure_connection
          if @config[:encoding]
            @connection.set_client_encoding(@config[:encoding])
          end
          self.client_min_messages = @config[:min_messages] if @config[:min_messages]
          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]

          # Use standard-conforming strings if available so we don't have to do the E'...' dance.
          set_standard_conforming_strings

          # If using Active Record's time zone support configure the connection to return
          # TIMESTAMP WITH ZONE types in UTC.
          if ActiveRecord::Base.default_timezone == :utc
            execute("SET time zone 'UTC'", 'SCHEMA')
          elsif @local_tz
            execute("SET time zone '#{@local_tz}'", 'SCHEMA')
          end
        end

1090
        # Returns the current ID of a table's sequence: currval(+sequence_name+),
        # with the sequence name passed as a bind parameter. The first column of
        # the first row is converted with Integer().
        def last_insert_id(sequence_name) #:nodoc:
          r = exec_query("SELECT currval($1)", 'SQL', [[nil, sequence_name]])
          Integer(r.rows.first.first)
        end

1096
        # Executes a SELECT query and returns the results, performing any data type
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
        #
        # Returns exec_query(...).to_a.
        # NOTE(review): presumably an array of row hashes; confirm against the
        # result class.
        def select(sql, name = nil, binds = [])
          exec_query(sql, name, binds).to_a
        end

        # Executes +sql+ and returns <tt>[fields, rows]</tt>: the result's column
        # names and its rows as converted by result_as_array.
        #
        # The raw result is cleared in an ensure clause, so it is released even
        # when converting the rows raises.
        def select_raw(sql, name = nil)
          res = execute(sql, name)
          results = result_as_array(res)
          fields = res.fields
          [fields, results]
        ensure
          res.clear if res
        end

1110
        # Returns the list of a table's column names, data types, and default values.
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - ::regclass is a function that gives the id for a table name
        #  - pg_get_expr(adbin, adrelid) renders the default expression; the
        #    pg_attrdef.adsrc column used previously was removed in PostgreSQL 12
        def column_definitions(table_name) #:nodoc:
          exec_query(<<-end_sql, 'SCHEMA').rows
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), pg_get_expr(d.adbin, d.adrelid), a.attnotnull
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{quote_table_name(table_name)}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
        end
1138 1139

        # Splits the first identifier off +name+.
        #
        # Returns <tt>[identifier, rest]</tt>, or nil when nothing matches:
        # * +identifier+ is the contents of the leading double-quoted name, or
        #   otherwise the first run of non-dot characters;
        # * +rest+ is the remainder with one separating dot removed, or nil when
        #   nothing remains.
        #
        #   extract_pg_identifier_from_name('schema.table')      # => ["schema", "table"]
        #   extract_pg_identifier_from_name('"my.schema".table') # => ["my.schema", "table"]
        def extract_pg_identifier_from_name(name)
          match_data = name.start_with?('"') ? name.match(/\"([^\"]+)\"/) : name.match(/([^\.]+)/)
          return unless match_data

          rest = name[match_data[0].length..-1]
          rest = rest[1..-1] if rest.start_with?('.')
          [match_data[1], rest.empty? ? nil : rest]
        end
1148

1149 1150 1151 1152 1153
        # Returns the table reference of an INSERT statement, or nil when +sql+
        # does not match. The reference is the text between INTO and the first
        # "(", with surrounding whitespace stripped.
        #
        # The /m flag lets ".*" span newlines. Without it, an INSERT whose VALUES
        # keyword sits on a later line than the column list was not recognised.
        def extract_table_ref_from_insert_sql(sql)
          table_ref = sql[/into\s+([^\(]*).*values\s*\(/im, 1]
          table_ref.strip if table_ref
        end

1154 1155 1156
        # Builds a new TableDefinition bound to this adapter instance.
        def table_definition
          TableDefinition.new(self)
        end
D
Initial  
David Heinemeier Hansson 已提交
1157 1158 1159
    end
  end
end