postgresql_adapter.rb 41.0 KB
Newer Older
D
Initial  
David Heinemeier Hansson 已提交
1
require 'active_record/connection_adapters/abstract_adapter'
2
require 'active_support/core_ext/object/blank'
3
require 'active_record/connection_adapters/statement_pool'
4 5 6

# Make sure we're using pg high enough for PGResult#values
gem 'pg', '~> 0.11'
7
require 'pg'
D
Initial  
David Heinemeier Hansson 已提交
8 9 10 11 12

module ActiveRecord
  class Base
    # Establishes a connection to the database that's used by all Active Record objects.
    #
    # +config+ is a database.yml-style hash whose keys may be strings or symbols.
    # :database is mandatory; :host, :username and :password are optional and
    # :port falls back to 5432. The symbolized copy of +config+ is what gets
    # handed on to the adapter.
    #
    # Raises ArgumentError when +config+ has no :database key at all.
    def self.postgresql_connection(config) # :nodoc:
      settings = config.symbolize_keys

      host = settings[:host]
      port = settings[:port] || 5432
      user = settings[:username] ? settings[:username].to_s : nil
      pass = settings[:password] ? settings[:password].to_s : nil

      unless settings.key?(:database)
        raise ArgumentError, "No database specified. Missing argument: database."
      end
      database = settings[:database]

      # No connection object exists yet (nil is passed); the adapter receives the
      # connection arguments instead and opens the connection itself.
      connection_args = [host, port, nil, nil, database, user, pass]
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, connection_args, settings)
    end
  end
30

31 32 33 34 35 36 37
  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      #
      # +default+ is the raw default expression reported by the server; it is
      # reduced to a literal value (or nil) by extract_value_from_default
      # before being handed to Column.
      def initialize(name, default, sql_type = nil, null = true)
        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end

      # :stopdoc:
      class << self
        # Precision reported for +money+ columns (read by extract_precision).
        attr_accessor :money_precision

        # Like the superclass implementation, but maps PostgreSQL's 'infinity'
        # and '-infinity' literals to the Float infinities. Non-String input is
        # returned unchanged.
        def string_to_time(string)
          return string unless String === string

          case string
          when 'infinity'  then 1.0 / 0.0
          when '-infinity' then -1.0 / 0.0
          else
            super
          end
        end
      end
      # :startdoc:

      private
        # Byte sizes for the integer types PostgreSQL names explicitly.
        def extract_limit(sql_type)
          case sql_type
          when /^bigint/i;    8
          when /^smallint/i;  2
          else super
          end
        end

        # Extracts the scale from PostgreSQL-specific data types.
        def extract_scale(sql_type)
          # Money type has a fixed scale of 2.
          sql_type =~ /^money/ ? 2 : super
        end

        # Extracts the precision from PostgreSQL-specific data types.
        # For +money+ the class-level money_precision setting is used.
        def extract_precision(sql_type)
          if sql_type == 'money'
            self.class.money_precision
          else
            super
          end
        end

        # Maps PostgreSQL-specific data types to logical Rails types.
        def simplified_type(field_type)
          case field_type
            # Floating-point types
            when /^(?:real|double precision)$/
              :float
            # Monetary types
            when 'money'
              :decimal
            # Character types
            when /^(?:character varying|bpchar)(?:\(\d+\))?$/
              :string
            # Binary data types
            when 'bytea'
              :binary
            # Date/time types
            when /^timestamp with(?:out)? time zone$/
              :datetime
            when 'interval'
              :string
            # Geometric types
            when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
              :string
            # Network address types
            when /^(?:cidr|inet|macaddr)$/
              :string
            # Bit strings
            when /^bit(?: varying)?(?:\(\d+\))?$/
              :string
            # XML type
            when 'xml'
              :xml
            # tsvector type
            when 'tsvector'
              :tsvector
            # Arrays
            when /^\D+\[\]$/
              :string
            # Object identifier types
            when 'oid'
              :integer
            # UUID type
            when 'uuid'
              :string
            # Small and big integer types
            when /^(?:small|big)int$/
              :integer
            # Pass through all types that are not specific to PostgreSQL.
            else
              super
          end
        end

        # Extracts the value from a PostgreSQL column default definition.
        #
        # Literal defaults arrive as SQL expressions such as
        # "'foo'::character varying" or "(-1)"; they are reduced to the literal
        # text (or true/false for booleans). Anything unrecognised yields nil.
        def self.extract_value_from_default(default)
          case default
            # This is a performance optimization for Ruby 1.9.2 in development.
            # If the value is nil, we return nil straight away without checking
            # the regular expressions. If we check each regular expression,
            # Regexp#=== will call NilClass#to_str, which will trigger
            # method_missing (defined by whiny nil in ActiveSupport) which
            # makes this method very very slow.
            when NilClass
              nil
            # Numeric types. Negative values may be wrapped in parentheses
            # ("(-1)"); the closing parenthesis is kept outside the capture so
            # the value is "-1" rather than "-1)". Plain integer defaults,
            # including those of object identifier (oid) columns, match here too.
            when /\A\(?(-?\d+(\.\d*)?)\)?\z/
              $1
            # Character types
            when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
              $1
            # Character types (8.1 formatting)
            when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
            # Binary data types
            when /\A'(.*)'::bytea\z/m
              $1
            # Date/time types
            when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
              $1
            when /\A'(.*)'::interval\z/
              $1
            # Boolean type
            when 'true'
              true
            when 'false'
              false
            # Geometric types
            when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
              $1
            # Network address types
            when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
              $1
            # Bit string types
            when /\AB'(.*)'::"?bit(?: varying)?"?\z/
              $1
            # XML type
            when /\A'(.*)'::xml\z/m
              $1
            # Arrays
            when /\A'(.*)'::"?\D+"?\[\]\z/
              $1
            else
              # Anything else is blank, some user type, or some function
              # and we can't know the value of that, so return nil.
              nil
          end
        end
    end

191 192
    # The PostgreSQL adapter talks to the server through the +pg+ gem (version ~> 0.11,
    # required at the top of this file for PGresult#values).
193 194 195
    #
    # Options:
    #
P
Pratik Naik 已提交
196 197 198 199 200
    # * <tt>:host</tt> - Defaults to "localhost".
    # * <tt>:port</tt> - Defaults to 5432.
    # * <tt>:username</tt> - Defaults to nothing.
    # * <tt>:password</tt> - Defaults to nothing.
    # * <tt>:database</tt> - The name of the database. No default, must be provided.
201
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection given
202
    #   as a string of comma-separated schema names. This is backward-compatible with the <tt>:schema_order</tt> option.
203
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO
204
    #   <encoding></tt> call on the connection.
205
    # * <tt>:min_messages</tt> - An optional client min messages that is used in a
206
    #   <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
207
    class PostgreSQLAdapter < AbstractAdapter
208 209 210 211 212
      class TableDefinition < ActiveRecord::ConnectionAdapters::TableDefinition
        # Column helpers for the PostgreSQL-only +xml+ and +tsvector+ types:
        #   t.xml      :document
        #   t.tsvector :search_vector, :null => false
        # The first argument is the column name and a trailing hash is taken as
        # the column options; any further names are ignored.
        %w(xml tsvector).each do |pg_type|
          define_method(pg_type) do |*args|
            column_options = args.extract_options!
            column(args[0], pg_type.dup, column_options)
          end
        end
      end

220
      ADAPTER_NAME = "PostgreSQL"

      # Rails abstract column type => PostgreSQL DDL type. Each entry is a hash
      # holding the type :name (plus a default :limit for strings), except
      # :primary_key, which is a complete column definition string.
      NATIVE_DATABASE_TYPES = {
        :primary_key => "serial primary key",
        :string      => { :name => "character varying", :limit => 255 },
        :text        => { :name => "text" },
        :integer     => { :name => "integer" },
        :float       => { :name => "float" },
        :decimal     => { :name => "decimal" },
        :datetime    => { :name => "timestamp" },
        :timestamp   => { :name => "timestamp" },
        :time        => { :name => "time" },
        :date        => { :name => "date" },
        :binary      => { :name => "bytea" },
        :boolean     => { :name => "boolean" },
        :xml         => { :name => "xml" },
        :tsvector    => { :name => "tsvector" }
      }

239
      # Returns 'PostgreSQL' as adapter name for identification purposes.
240
      def adapter_name
        # Same value as the ADAPTER_NAME constant.
        ADAPTER_NAME
      end

244 245
      # Returns +true+, since this connection adapter supports prepared statement
      # caching.
246 247 248 249
      def supports_statement_cache?
        # Prepared statements are cached in a StatementPool.
        true
      end

250 251 252 253
      # Cache mapping SQL strings to the names of server-side prepared
      # statements ("a1", "a2", ...).
      #
      # Entries are kept per process id ($$), presumably so that a forked child
      # starts empty instead of reusing statement names prepared by its parent.
      class StatementPool < ConnectionAdapters::StatementPool
        def initialize(connection, max)
          super
          @counter = 0
          @cache   = Hash.new { |h,pid| h[pid] = {} }
        end

        def each(&block); cache.each(&block); end
        def key?(key);    cache.key?(key); end
        def [](key);      cache[key]; end
        def length;       cache.length; end

        # Name to use for the next statement. The counter only advances when
        # an entry is stored via []=.
        def next_key
          "a#{@counter + 1}"
        end

        # Stores statement name +key+ for +sql+. While the cache is at capacity
        # (@max), the oldest entries are removed and their statements deallocated.
        def []=(sql, key)
          while @max <= cache.size
            dealloc(cache.shift.last)
          end
          @counter += 1
          cache[sql] = key
        end

        # Deallocates every cached statement and empties the cache.
        def clear
          cache.each_value do |stmt_key|
            dealloc stmt_key
          end
          cache.clear
        end

        # Forgets +sql_key+, deallocating its statement first. Unknown keys are
        # ignored; deallocating them would issue "DEALLOCATE " with no
        # statement name.
        def delete(sql_key)
          return unless cache.key?(sql_key)
          dealloc cache[sql_key]
          cache.delete sql_key
        end

        private
        def cache
          @cache[$$]
        end

        # Issues DEALLOCATE only while the connection reports CONNECTION_OK.
        def dealloc(key)
          @connection.query "DEALLOCATE #{key}" if connection_active?
        end

        def connection_active?
          @connection.status == PGconn::CONNECTION_OK
        rescue PGError
          false
        end
      end

302 303
      # Initializes and connects a PostgreSQL adapter.
      # +connection_parameters+ is the connection argument array built by
      # Base.postgresql_connection; +config+ is the symbolized database config,
      # whose optional :statement_limit caps the prepared statement cache
      # (default 1000). Raises a RuntimeError when postgresql_version < 80200.
      def initialize(connection, logger, connection_parameters, config)
        super(connection, logger)
        @connection_parameters = connection_parameters
        @config                = config

        # Seed lazily-populated ivars with nil so they are defined before
        # connect runs (it may already read @local_tz).
        @local_tz           = nil
        @table_alias_length = nil

        connect

        statement_limit = config.fetch(:statement_limit) { 1000 }
        @statements = StatementPool.new(@connection, statement_limit)

        if postgresql_version < 80200
          raise "Your version of PostgreSQL (#{postgresql_version}) is too old, please upgrade!"
        end

        timezone_row = execute('SHOW TIME ZONE', 'SCHEMA').first
        @local_tz = timezone_row["TimeZone"]
      end

322 323 324 325
      # Builds the PostgreSQL-flavoured Arel visitor for +pool+.
      def self.visitor_for(pool) # :nodoc:
        visitor_class = Arel::Visitors::PostgreSQL
        visitor_class.new(pool)
      end

X
Xavier Noria 已提交
326
      # Clears the prepared statements cache.
327 328 329 330
      def clear_cache!
        # Delegates to the statement pool, which DEALLOCATEs each entry.
        @statements.clear
      end

331 332
      # Is this connection alive and ready for queries?
      def active?
        # A PGError while reading the status is reported as "not active".
        PGconn::CONNECTION_OK == @connection.status
      rescue PGError
        false
      end

      # Close then reopen the connection.
      def reconnect!
        clear_cache!          # drop cached prepared statements before the reset
        @connection.reset
        configure_connection  # re-apply session settings on the reset connection
      end
344

345 346 347 348 349
      def reset!
        # Discard prepared statements, then let the generic reset run.
        clear_cache!
        super
      end

350 351
      # Disconnects from the database if already connected. Otherwise, this
      # method does nothing.
352
      def disconnect!
        clear_cache!
        # Closing is best effort: any StandardError from close is ignored.
        begin
          @connection.close
        rescue StandardError
          nil
        end
      end
356

357
      def native_database_types #:nodoc:
        # The shared constant itself (not a copy) is returned.
        NATIVE_DATABASE_TYPES
      end
360

361
      # Returns true, since this connection adapter supports migrations.
362 363
      def supports_migrations?
        true # unconditionally
      end

366
      # Does PostgreSQL support finding primary key on non-Active Record tables?
367 368 369 370
      def supports_primary_key? #:nodoc:
        true # unconditionally
      end

371 372 373
      # Enable standard-conforming strings if available.
      #
      # client_min_messages is raised to 'panic' for the duration of the SET and
      # restored afterwards. The SET itself is best effort: any StandardError
      # from it is ignored.
      #
      # The current level is read before anything is changed. If that read
      # fails, its error propagates untouched. Previously the ensure clause
      # would then try to "restore" a nil level, and the resulting error
      # masked the original one.
      def set_standard_conforming_strings
        old = client_min_messages
        begin
          self.client_min_messages = 'panic'
          begin
            execute('SET standard_conforming_strings = on', 'SCHEMA')
          rescue StandardError
            nil
          end
        ensure
          self.client_min_messages = old
        end
      end

379
      def supports_insert_with_returning?
        true # INSERT ... RETURNING is used by insert_sql / sql_for_insert
      end

383 384 385
      def supports_ddl_transactions?
        true # unconditionally
      end
386

387
      # Returns true, since this connection adapter supports savepoints.
388 389 390
      def supports_savepoints?
        true # see create_savepoint / rollback_to_savepoint / release_savepoint
      end
391

392
      # Returns the identifier length limit configured on the PostgreSQL server (max_identifier_length).
393
      def table_alias_length
        return @table_alias_length if @table_alias_length
        # The server reports the setting as text; the integer is memoized.
        @table_alias_length = query('SHOW max_identifier_length')[0][0].to_i
      end
396

397 398
      # QUOTING ==================================================

399
      # Escapes binary strings for bytea input to the database.
400 401
      def escape_bytea(value)
        # nil (or false) input yields nil without touching the connection.
        return nil unless value
        @connection.escape_bytea(value)
      end

      # Unescapes bytea output from a database to the binary string it represents.
      # NOTE: This is NOT an inverse of escape_bytea! This is only to be used
      #       on escaped binary output from the database driver.
407 408
      def unescape_bytea(value)
        # nil (or false) input yields nil without touching the connection.
        return nil unless value
        @connection.unescape_bytea(value)
      end

411 412
      # Quotes PostgreSQL-specific data types for SQL input.
      #
      # Without a +column+ this defers entirely to the generic implementation.
      # With one, it special-cases:
      # * infinite Floats for +datetime+ columns -> 'infinity' / '-infinity'
      # * Numerics for +money+ columns -> the bare number in single quotes
      # * Strings for +bytea+, +xml+ and +bit+ columns
      # Everything else is left to +super+.
      def quote(value, column = nil) #:nodoc:
        return super unless column

        case value
        when Float
          return super unless value.infinite? && column.type == :datetime
          "'#{value.to_s.downcase}'"
        when Numeric
          return super unless column.sql_type == 'money'
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value}'"
        when String
          case column.sql_type
          when 'bytea' then "'#{escape_bytea(value)}'"
          when 'xml'   then "xml '#{quote_string(value)}'"
          when /^bit/
            # The value is interpolated without escaping, so the WHOLE string
            # must be validated: \A...\z anchor the entire input, whereas ^...$
            # accepted e.g. "01\n'); DROP TABLE ..." because its first line is
            # a valid bit string.
            case value
            when /\A[01]*\z/      then "B'#{value}'" # Bit-string notation
            when /\A[0-9A-F]*\z/i then "X'#{value}'" # Hexadecimal notation
            else super # not bit/hex notation: generic quoting instead of nil
            end
          else
            super
          end
        else
          super
        end
      end

440 441 442 443 444 445
      # Casts +value+ for use as a bind parameter. Strings bound to +bytea+
      # columns become a { :value, :format => 1 } hash (presumably requesting
      # binary transmission from the pg driver); everything else goes to super.
      def type_cast(value, column)
        return super unless column
        return super unless String === value && 'bytea' == column.sql_type

        { :value => value, :format => 1 }
      end

452 453 454
      # Quotes strings for use in SQL input.
      def quote_string(s) #:nodoc:
        # Escaping is delegated to the connection's escape method.
        escaped = @connection.escape(s)
        escaped
      end

457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475
      # Checks the following cases:
      #
      # - table_name
      # - "table.name"
      # - schema_name.table_name
      # - schema_name."table.name"
      # - "schema.name".table_name
      # - "schema.name"."table.name"
      def quote_table_name(name)
        schema, remainder = extract_pg_identifier_from_name(name.to_s)
        # A bare identifier: quote it as-is.
        return quote_column_name(schema) unless remainder

        table_name, _rest = extract_pg_identifier_from_name(remainder)
        "#{quote_column_name(schema)}.#{quote_column_name(table_name)}"
      end

476 477
      # Quotes column names for use in SQL queries.
      def quote_column_name(name) #:nodoc:
        # Symbols and other objects are converted to String first.
        identifier = name.to_s
        PGconn.quote_ident(identifier)
      end

481 482 483
      # Quote date/time values for use in SQL input. Includes microseconds
      # if the value is a Time responding to usec.
      def quoted_date(value) #:nodoc:
        append_usec = value.acts_like?(:time) && value.respond_to?(:usec)
        quoted = super
        return quoted unless append_usec
        # Microseconds are zero-padded to six digits (5 usec -> ".000005").
        "#{quoted}.#{'%06d' % value.usec}"
      end

491 492
      # Set the authorized user for this session
      # +user+ is interpolated verbatim (not quoted), so it must come from a
      # trusted source. Cached prepared statements are discarded first.
      def session_auth=(user)
        clear_cache!
        statement = "SET SESSION AUTHORIZATION #{user}"
        exec_query statement
      end

497 498
      # REFERENTIAL INTEGRITY ====================================

499
      def supports_disable_referential_integrity? #:nodoc:
        true # unconditionally
      end

503
      # Runs the block with all triggers (and hence FK checks) disabled on every
      # table from #tables, re-enabling them afterwards even if the block raises.
      # Returns the block's value.
      def disable_referential_integrity #:nodoc:
        set_triggers = lambda do |state|
          next unless supports_disable_referential_integrity?
          statements = tables.map { |t| "ALTER TABLE #{quote_table_name(t)} #{state} TRIGGER ALL" }
          execute(statements.join(";"))
        end

        begin
          set_triggers.call('DISABLE')
          yield
        ensure
          set_triggers.call('ENABLE')
        end
      end
513 514 515

      # DATABASE STATEMENTS ======================================

516 517 518 519 520 521
      # Executes a SELECT query and returns an array of rows. Each row is an
      # array of field values.
      def select_rows(sql, name = nil)
        # select_raw's last element is the row data.
        raw = select_raw(sql, name)
        raw.last
      end

522
      # Executes an INSERT query and returns the new record's ID
523
      def insert_sql(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        unless pk
          # No explicit key: derive it from the table named in the INSERT.
          table_ref = extract_table_ref_from_insert_sql(sql)
          pk = primary_key(table_ref) if table_ref
        end

        # Without a primary key there is nothing to RETURN; use the generic path
        # (super receives the possibly-updated pk).
        return super unless pk

        select_value("#{sql} RETURNING #{quote_column_name(pk)}")
      end
536
      alias_method :create, :insert # #create is kept as another name for #insert
537

538 539
      # create a 2D array representing the result set
      #
      # Rows come from PGresult#values unchanged, except for two column types,
      # whose values are rewritten in each row array:
      # * bytea columns are passed through unescape_bytea;
      # * money columns have currency symbols and grouping separators stripped.
      # Returns the array of rows.
      def result_as_array(res) #:nodoc:
        # [column index, type OID] for every column of the result.
        ftypes = Array.new(res.nfields) { |i| [i, res.ftype(i)] }

        rows = res.values
        return rows unless ftypes.any? { |_, oid|
          oid == BYTEA_COLUMN_TYPE_OID || oid == MONEY_COLUMN_TYPE_OID
        }

        typehash = ftypes.group_by { |_, oid| oid }
        binaries = typehash[BYTEA_COLUMN_TYPE_OID] || []
        monies   = typehash[MONEY_COLUMN_TYPE_OID] || []

        rows.each do |row|
          binaries.each do |index, _|
            row[index] = unescape_bytea(row[index])
          end

          # If this is a money type column and there are any currency symbols,
          # then strip them off. Indeed it would be prettier to do this in
          # PostgreSQLColumn.string_to_decimal but would break form input
          # fields that call value_before_type_cast.
          monies.each do |index, _|
            data = row[index]
            # Because money output is formatted according to the locale, there are two
            # cases to consider (note the decimal separators):
            #  (1) $12,345,678.12
            #  (2) $12.345.678,12
            # Non-bang gsub/sub are used deliberately: gsub! returns nil when it
            # changes nothing, so chaining .sub! onto it could raise NoMethodError.
            case data
            when /^-?\D+[\d,]+\.\d{2}$/  # (1)
              row[index] = data.gsub(/[^-\d.]/, '')
            when /^-?\D+[\d.]+,\d{2}$/   # (2)
              row[index] = data.gsub(/[^-\d,]/, '').sub(',', '.')
            end
          end
        end

        rows
      end


      # Queries the database and returns the results in an Array-like object
582
      def query(sql, name = nil) #:nodoc:
        # Runs +sql+ inside log() and converts the PGresult via result_as_array.
        log(sql, name) { result_as_array(@connection.async_exec(sql)) }
      end

588
      # Executes an SQL statement, returning a PGresult object on success
589 590
      # or raising a PGError exception otherwise.
      def execute(sql, name = nil)
        # The raw PGresult is returned; callers are responsible for it.
        log(sql, name) { @connection.async_exec(sql) }
      end

596 597
      # Bind placeholder for the zero-based +index+: PostgreSQL's $1, $2, ...
      def substitute_at(column, index)
        placeholder = "$#{index + 1}"
        Arel.sql(placeholder)
      end

A
Aaron Patterson 已提交
600
      # Executes +sql+ (through the prepared-statement cache when +binds+ are
      # given) and returns an ActiveRecord::Result.
      #
      # The PGresult is cleared in an ensure clause, so it is released even if
      # building the Result raises.
      def exec_query(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = binds.empty? ? exec_no_cache(sql, binds) :
                                  exec_cache(sql, binds)
          begin
            ActiveRecord::Result.new(result.fields, result_as_array(result))
          ensure
            result.clear
          end
        end
      end

611 612 613 614 615 616 617 618 619
      # Executes a DELETE (or, via the exec_update alias, an UPDATE) and
      # returns the number of affected rows (PGresult#cmd_tuples).
      # The PGresult is always cleared, even if reading the count raises.
      def exec_delete(sql, name = 'SQL', binds = [])
        log(sql, name, binds) do
          result = binds.empty? ? exec_no_cache(sql, binds) :
                                  exec_cache(sql, binds)
          begin
            result.cmd_tuples
          ensure
            result.clear
          end
        end
      end
      alias :exec_update :exec_delete
621

622 623
      # Returns [sql, binds], appending "RETURNING <pk>" when a primary key is
      # given or can be derived from the INSERT's target table.
      def sql_for_insert(sql, pk, id_value, sequence_name, binds)
        pk ||= begin
          # Extract the table from the insert sql. Yuck.
          table_ref = extract_table_ref_from_insert_sql(sql)
          table_ref && primary_key(table_ref)
        end

        return [sql, binds] unless pk

        ["#{sql} RETURNING #{quote_column_name(pk)}", binds]
      end

634
      # Executes an UPDATE query and returns the number of affected tuples.
635
      def update_sql(sql, name = nil)
        # The generic implementation hands back the PGresult; report its row count.
        result = super
        result.cmd_tuples
      end

639 640
      # Begins a transaction.
      def begin_db_transaction
        execute("BEGIN")
      end

644 645
      # Commits a transaction.
      def commit_db_transaction
        execute("COMMIT")
      end
648

649 650
      # Aborts a transaction.
      def rollback_db_transaction
        execute("ROLLBACK")
      end
653

654 655
      # True when the connection reports the idle (no open transaction) status.
      def outside_transaction?
        status = @connection.transaction_status
        status == PGconn::PQTRANS_IDLE
      end
657

J
Jonathan Viney 已提交
658 659 660 661 662 663 664 665
      def create_savepoint
        savepoint = current_savepoint_name
        execute("SAVEPOINT #{savepoint}")
      end

      def rollback_to_savepoint
        savepoint = current_savepoint_name
        execute("ROLLBACK TO SAVEPOINT #{savepoint}")
      end

666
      def release_savepoint
        savepoint = current_savepoint_name
        execute("RELEASE SAVEPOINT #{savepoint}")
      end
669

670 671
      # SCHEMA STATEMENTS ========================================

672 673 674
      # Drops the database specified on the +name+ attribute
      # and creates it again using the provided +options+.
      def recreate_database(name, options = {}) #:nodoc:
        drop_database(name)            # no-op when it does not exist (IF EXISTS)
        create_database(name, options)
      end

679
      # Create a new PostgreSQL database. Options include <tt>:owner</tt>, <tt>:template</tt>,
680 681
      # <tt>:encoding</tt>, <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
      # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
682 683 684 685 686 687 688 689 690 691
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        options = options.reverse_merge(:encoding => "utf8")

        # One clause per recognised option; unknown keys contribute nothing.
        # The clauses are joined with map/join rather than Enumerable#sum: on
        # Ruby >= 2.4 #sum starts from 0 and raises TypeError for Strings.
        option_string = options.symbolize_keys.map do |key, value|
          case key
          when :owner
            " OWNER = \"#{value}\""
          when :template
            " TEMPLATE = \"#{value}\""
          when :encoding
            " ENCODING = '#{value}'"
          when :tablespace
            " TABLESPACE = \"#{value}\""
          when :connection_limit
            " CONNECTION LIMIT = #{value}"
          else
            ""
          end
        end.join

        execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
      end

709
      # Drops a PostgreSQL database.
710 711 712 713
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        # IF EXISTS makes dropping a missing database a no-op.
        statement = "DROP DATABASE IF EXISTS #{quote_table_name(name)}"
        execute statement
      end

717 718
      # Returns the list of all tables in the schema search path or a specified schema.
      def tables(name = nil)
        # +name+ is accepted for interface compatibility but not used.
        rows = query(<<-SQL, 'SCHEMA')
          SELECT tablename
          FROM pg_tables
          WHERE schemaname = ANY (current_schemas(false))
        SQL
        rows.map { |row| row.first }
      end

726
      # Returns true if table exists.
727 728
      # If the schema is not specified as part of +name+ then it will only find tables within
      # the current schema search path (regardless of permissions to access tables in other schemas)
729
      def table_exists?(name)
        schema, table = Utils.extract_schema_and_table(name.to_s)
        return false unless table

        # $1 is the table; $2 (only when given) the schema. Without an explicit
        # schema any schema on the search path qualifies.
        binds = [[nil, table]]
        binds << [nil, schema] if schema
        namespace = schema ? '$2' : 'ANY (current_schemas(false))'

        count = exec_query(<<-SQL, 'SCHEMA', binds).rows.first[0].to_i
            SELECT COUNT(*)
            FROM pg_class c
            LEFT JOIN pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind in ('v','r')
            AND c.relname = $1
            AND n.nspname = #{namespace}
        SQL
        count > 0
      end

746 747 748 749 750 751 752 753
      # Returns true if schema exists.
      def schema_exists?(name)
        matches = exec_query(<<-SQL, 'SCHEMA', [[nil, name]]).rows.first[0]
          SELECT COUNT(*)
          FROM pg_namespace
          WHERE nspname = $1
        SQL
        matches.to_i > 0
      end
754

755
      # Returns an array of indexes for the given table.
756
      #
      # Only non-primary-key indexes in schemas on the current search path are
      # returned. Indexes whose key columns cannot be resolved to attribute
      # names are omitted.
      #
      # The schemas come from current_schemas(false), as in #tables and
      # #table_exists?, rather than from splitting schema_search_path. That
      # string may contain spaces after commas and double-quoted names (such as
      # "$user"), which never matched nspname once wrapped in quotes.
      # +table_name+ is escaped before being interpolated.
      def indexes(table_name, name = nil)
        result = query(<<-SQL, name)
          SELECT distinct i.relname, d.indisunique, d.indkey, t.oid
          FROM pg_class t
          INNER JOIN pg_index d ON t.oid = d.indrelid
          INNER JOIN pg_class i ON d.indexrelid = i.oid
          WHERE i.relkind = 'i'
            AND d.indisprimary = 'f'
            AND t.relname = '#{quote_string(table_name.to_s)}'
            AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname = ANY (current_schemas(false)) )
          ORDER BY i.relname
        SQL

        result.map do |row|
          index_name = row[0]
          unique     = row[1] == 't'
          indkey     = row[2].split(" ")
          oid        = row[3]

          # attnum => attname for the index's key columns.
          columns = Hash[query(<<-SQL, "Columns for index #{row[0]} on #{table_name}")]
          SELECT a.attnum, a.attname
          FROM pg_attribute a
          WHERE a.attrelid = #{oid}
          AND a.attnum IN (#{indkey.join(",")})
          SQL

          column_names = columns.values_at(*indkey).compact
          column_names.empty? ? nil : IndexDefinition.new(table_name, index_name, unique, column_names)
        end.compact
      end

789 790
      # Returns the list of all column definitions for a table.
      def columns(table_name, name = nil)
        # Limit, precision, and scale are all handled by the superclass.
        column_definitions(table_name).map do |column_name, type, default, notnull|
          nullable = notnull == 'f' # catalog booleans arrive as 't' / 'f'
          PostgreSQLColumn.new(column_name, default, type, nullable)
        end
      end

797 798 799 800 801
      # Returns the current database name.
      def current_database
        query('select current_database()').first.first
      end

802 803 804 805 806
      # Returns the current schema name.
      def current_schema
        query('SELECT current_schema', 'SCHEMA').first.first
      end

807 808 809 810 811 812 813 814
      # Returns the current database encoding format.
      #
      # The database is matched with "= current_database()" on the server.
      # The previous LIKE against the interpolated name treated '_' and '%' in
      # the name as wildcards, and that name was not escaped. This also saves
      # the extra current_database round-trip.
      def encoding
        query(<<-end_sql)[0][0]
          SELECT pg_encoding_to_char(pg_database.encoding) FROM pg_database
          WHERE pg_database.datname = current_database()
        end_sql
      end

815 816 817 818 819 820
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      #
      # This should not be called manually but set in database.yml.
      def schema_search_path=(schema_csv)
        # nil/false leaves both the session and the memoized value untouched.
        return unless schema_csv

        execute "SET search_path TO #{schema_csv}"
        @schema_search_path = schema_csv
      end

827 828
      # Returns the active schema search path.
      def schema_search_path
        return @schema_search_path if @schema_search_path
        # First call asks the server; the answer is memoized.
        @schema_search_path = query('SHOW search_path')[0][0]
      end
831

832 833
      # Returns the current client message level.
      def client_min_messages
        query('SHOW client_min_messages', 'SCHEMA').first.first
      end

      # Set the client message level.
      def client_min_messages=(level)
        statement = "SET client_min_messages TO '#{level}'"
        execute(statement, 'SCHEMA')
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        column = pk || 'id'
        begin
          # Keep only the unqualified sequence name.
          serial_sequence(table_name, column).split('.').last
        rescue ActiveRecord::StatementInvalid
          "#{table_name}_#{column}_seq"
        end
      end

      # Name of the sequence owned by +table+.+column+, as reported by
      # pg_get_serial_sequence (table and column are passed as binds).
      def serial_sequence(table, column)
        binds = [[nil, table], [nil, column]]
        rows = exec_query(<<-eosql, 'SCHEMA', binds).rows
          SELECT pg_get_serial_sequence($1, $2)
        eosql
        rows.first.first
      end

856 857
      # Resets the sequence of a table's primary key to the maximum value.
      #
      # Missing +pk+ / +sequence+ arguments are filled in from
      # #pk_and_sequence_for. If a primary key is known but no sequence is, a
      # warning is logged (when a logger is set) and nothing is executed.
      # Returns the result of the +setval+ query, or +nil+ if it was skipped.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        if !pk || !sequence
          found_pk, found_sequence = pk_and_sequence_for(table)
          pk       ||= found_pk
          sequence ||= found_sequence
        end

        if pk && !sequence
          @logger.warn "#{table} has primary key #{pk} with no default sequence" if @logger
        end
        return unless pk && sequence

        quoted_sequence = quote_table_name(sequence)
        select_value <<-end_sql, 'Reset sequence'
            SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
        end_sql
      end

      # Returns a table's primary key and belonging sequence.
      #
      # Finds a sequence that depends (via pg_depend) on the first column of the
      # table's primary-key constraint. The result is <tt>[primary_key, sequence]</tt>;
      # the sequence name is prefixed with its schema unless that schema is
      # +public+. Returns +nil+ when no such sequence exists or the query raises.
      def pk_and_sequence_for(table) #:nodoc:
        row = exec_query(<<-end_sql, 'SCHEMA').rows.first
          SELECT attr.attname, ns.nspname, seq.relname
          FROM pg_class seq
          INNER JOIN pg_depend dep ON seq.oid = dep.objid
          INNER JOIN pg_attribute attr ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          INNER JOIN pg_namespace ns ON seq.relnamespace = ns.oid
          WHERE seq.relkind  = 'S'
            AND cons.contype = 'p'
            AND dep.refobjid = '#{quote_table_name(table)}'::regclass
        end_sql

        # No dependent sequence: say so explicitly instead of relying on a
        # NoMethodError from the nil row being swallowed by the rescue below.
        return nil unless row

        pk, schema, sequence_name = row
        sequence = schema == 'public' ? sequence_name : "#{schema}.#{sequence_name}"

        [pk, sequence]
      rescue
        # Best-effort lookup: any StandardError from the query is reported to
        # callers as "no primary key / sequence".
        nil
      end

      # Returns just a table's primary key column name.
      #
      # The table name is sent as a bind parameter and cast to +regclass+ on the
      # server. Returns +nil+ when the query yields no rows.
      def primary_key(table)
        rows = exec_query(<<-end_sql, 'SCHEMA', [[nil, table]]).rows
          SELECT DISTINCT(attr.attname)
          FROM pg_attribute attr
          INNER JOIN pg_depend dep ON attr.attrelid = dep.refobjid AND attr.attnum = dep.refobjsubid
          INNER JOIN pg_constraint cons ON attr.attrelid = cons.conrelid AND attr.attnum = cons.conkey[1]
          WHERE cons.contype = 'p'
            AND dep.refobjid = $1::regclass
        end_sql

        first_row = rows.first
        first_row.first if first_row
      end

      # Renames a table, clearing the adapter's cache first.
      #
      # Example:
      #   rename_table('octopuses', 'octopi')
      def rename_table(name, new_name)
        clear_cache!
        from = quote_table_name(name)
        to   = quote_table_name(new_name)
        execute "ALTER TABLE #{from} RENAME TO #{to}"
      end

      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
      def add_column(table_name, column_name, type, options = {})
        clear_cache!
        table    = quote_table_name(table_name)
        column   = quote_column_name(column_name)
        sql_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])
        statement = "ALTER TABLE #{table} ADD COLUMN #{column} #{sql_type}"
        # add_column_options! is expected to extend +statement+ in place; its
        # return value is ignored.
        add_column_options!(statement, options)
        execute statement
      end

      # Changes the column of a table.
      #
      # Alters the column's type, then applies <tt>:default</tt> and
      # <tt>:null</tt> from +options+ when they are present.
      def change_column(table_name, column_name, type, options = {})
        clear_cache!
        target   = quote_table_name(table_name)
        column   = quote_column_name(column_name)
        new_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])
        execute "ALTER TABLE #{target} ALTER COLUMN #{column} TYPE #{new_type}"

        if options_include_default?(options)
          change_column_default(table_name, column_name, options[:default])
        end
        if options.key?(:null)
          change_column_null(table_name, column_name, options[:null], options[:default])
        end
      end

      # Changes the default value of a table column; +default+ is passed
      # through #quote before being embedded in the statement.
      def change_column_default(table_name, column_name, default)
        clear_cache!
        table  = quote_table_name(table_name)
        column = quote_column_name(column_name)
        execute "ALTER TABLE #{table} ALTER COLUMN #{column} SET DEFAULT #{quote(default)}"
      end

      # Sets (+null+ falsy) or drops (+null+ truthy) the NOT NULL constraint on
      # a column. When setting it with a non-nil +default+, existing NULL values
      # are first replaced with +default+.
      def change_column_null(table_name, column_name, null, default = nil)
        clear_cache!
        table  = quote_table_name(table_name)
        column = quote_column_name(column_name)

        if !null && !default.nil?
          execute("UPDATE #{table} SET #{column}=#{quote(default)} WHERE #{column} IS NULL")
        end

        action = null ? 'DROP' : 'SET'
        execute("ALTER TABLE #{table} ALTER #{column} #{action} NOT NULL")
      end

      # Renames a column in a table, clearing the adapter's cache first.
      def rename_column(table_name, column_name, new_column_name)
        clear_cache!
        table    = quote_table_name(table_name)
        old_name = quote_column_name(column_name)
        new_name = quote_column_name(new_column_name)
        execute "ALTER TABLE #{table} RENAME COLUMN #{old_name} TO #{new_name}"
      end

      # Drops the index +index_name+ (quoted like a table name, so it may be
      # schema-qualified). +table_name+ is not used by this statement.
      def remove_index!(table_name, index_name) #:nodoc:
        quoted_index = quote_table_name(index_name)
        execute "DROP INDEX #{quoted_index}"
      end

      # Renames an index.
      #
      # The existing index is referenced with #quote_table_name, as in
      # #remove_index!, so a schema-qualified name such as "app.index_foo"
      # resolves correctly. PostgreSQL's <tt>RENAME TO</tt> accepts only an
      # unqualified identifier, so the new name is quoted as a single identifier.
      # +table_name+ is not used by this statement.
      def rename_index(table_name, old_name, new_name)
        execute "ALTER INDEX #{quote_table_name(old_name)} RENAME TO #{quote_column_name(new_name)}"
      end

      # Maximum length of an index name. PostgreSQL identifiers are limited to
      # NAMEDATALEN - 1 bytes, which is 63 with the default build setting.
      def index_name_length
        max_identifier_bytes = 63
        max_identifier_bytes
      end

      # Maps logical Rails types to PostgreSQL-specific data types.
      #
      # Only +integer+ is handled here: its +limit+ (a byte size) selects
      # smallint (1-2), integer (3-4 or no limit) or bigint (5-8). Any other
      # limit raises ActiveRecordError; every other type goes to the superclass.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        return super unless type.to_s == 'integer'

        if !limit
          'integer'
        elsif [1, 2].include?(limit)
          'smallint'
        elsif [3, 4].include?(limit)
          'integer'
        elsif (5..8).include?(limit)
          'bigint'
        else
          raise(ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead.")
        end
      end

      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column.
      #
      #   distinct("posts.id", ["posts.created_at desc"])
      #   # => "DISTINCT posts.id, posts.created_at AS alias_0"
      #
      # Sort modifiers (ASC/DESC, NULLS FIRST/LAST) are stripped from each ORDER BY
      # entry so only the bare expression is added to the select list; entries
      # that end up blank are dropped.
      def distinct(columns, orders) #:nodoc:
        return "DISTINCT #{columns}" if orders.empty?

        # Each modifier is removed together with its leading whitespace and must
        # end on a word boundary, so "created_at DESC NULLS LAST" becomes
        # "created_at" (not "created_atNULLS LAST") and a following word such as
        # "description" is not mistaken for DESC.
        order_columns = orders.map { |s|
          s.gsub(/\s+(?:ASC|DESC)\b/i, '').gsub(/\s+NULLS\s+(?:FIRST|LAST)\b/i, '').strip
        }.reject { |c| c.blank? }

        # Every entry was blank: appending ", " with nothing after it would be invalid SQL.
        return "DISTINCT #{columns}" if order_columns.empty?

        aliased = order_columns.each_with_index.map { |c, i| "#{c} AS alias_#{i}" }
        "DISTINCT #{columns}, #{aliased.join(', ')}"
      end

      module Utils
        extend self

        # Returns an array of <tt>[schema_name, table_name]</tt> extracted from +name+.
        # +schema_name+ is nil if not specified in +name+.
        # +schema_name+ and +table_name+ exclude surrounding quotes (regardless of whether provided in +name+).
        # +name+ supports the range of schema/table references understood by PostgreSQL, for example:
        #
        # * <tt>table_name</tt>
        # * <tt>"table.name"</tt>
        # * <tt>schema_name.table_name</tt>
        # * <tt>schema_name."table.name"</tt>
        # * <tt>"schema.name"."table name"</tt>
        def extract_schema_and_table(name)
          # Tokens are bare words or double-quoted identifiers; dots and
          # whitespace outside quotes separate them. Only the first two count.
          parts = name.scan(/[^".\s]+|"[^"]*"/)[0..1].map { |part| part.gsub(/(^"|"$)/, '') }

          if parts.size == 2
            parts
          else
            [nil, parts.first]
          end
        end
      end

      protected
1033
        # Returns the connected server's version as reported by the driver's
        # +server_version+ (an integer; #connect compares it against 80300).
        def postgresql_version
          conn = @connection
          conn.server_version
        end

        # Maps PostgreSQL error messages onto Active Record's more specific
        # exception classes (unique and foreign-key violations); anything
        # unrecognised is handed to the superclass implementation.
        def translate_exception(exception, message)
          pg_message = exception.message
          if pg_message =~ /duplicate key value violates unique constraint/
            RecordNotUnique.new(message, exception)
          elsif pg_message =~ /violates foreign key constraint/
            InvalidForeignKey.new(message, exception)
          else
            super
          end
        end

      private
1050 1051
        # SQLSTATE 0A000 (feature_not_supported). #exec_cache treats it as the
        # signal that a cached prepared statement must be re-prepared.
        FEATURE_NOT_SUPPORTED = "0A000" # :nodoc:

        # Runs +sql+ directly with +async_exec+, without preparing it. +binds+
        # is accepted for parity with #exec_cache but is not used.
        def exec_no_cache(sql, binds)
          conn = @connection
          conn.async_exec(sql)
        end

        # Executes +sql+ as a prepared statement, preparing it first if it is not
        # in the client-side statement cache, and returns the last result.
        #
        # If PostgreSQL rejects it with FEATURE_NOT_SUPPORTED, the cache entry is
        # discarded and the statement is prepared and run once more. Any other
        # error, or a second failure, is re-raised.
        def exec_cache(sql, binds)
          retried = false
          begin
            stmt_key = prepare_statement sql

            # Clear the queue
            @connection.get_last_result
            @connection.send_query_prepared(stmt_key, binds.map { |col, val|
              type_cast(val, col)
            })
            @connection.block
            @connection.get_last_result
          rescue PGError => e
            # Get the PG code for the failure.  Annoyingly, the code for
            # prepared statements whose return value may have changed is
            # FEATURE_NOT_SUPPORTED.  Check here for more details:
            # http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/backend/utils/cache/plancache.c#l573
            #
            # An error raised without an attached result carries no SQLSTATE;
            # re-raise it as-is rather than failing with NoMethodError here.
            pg_result = e.respond_to?(:result) ? e.result : nil
            code = pg_result && pg_result.result_error_field(PGresult::PG_DIAG_SQLSTATE)

            # 0A000 is also returned for SQL that is simply unsupported, which
            # fails the same way on every attempt; retrying only once keeps that
            # case from looping forever.
            if FEATURE_NOT_SUPPORTED == code && !retried
              retried = true
              @statements.delete sql_key(sql)
              retry
            else
              raise e
            end
          end
        end

        # Builds the client-side statement cache key for +sql+. The current
        # schema search path is part of the key, so identical SQL run under
        # different search paths is cached separately.
        def sql_key(sql)
          search_path = schema_search_path
          "#{search_path}-#{sql}"
        end

        # Prepares +sql+ on the server unless it is already cached under its
        # #sql_key, and returns the cached prepared statement name.
        def prepare_statement(sql)
          key = sql_key(sql)
          return @statements[key] if @statements.key?(key)

          statement_name = @statements.next_key
          @connection.prepare statement_name, sql
          @statements[key] = statement_name
          @statements[key]
        end

        # Type OID of PostgreSQL's built-in +money+ data type.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:
        # Type OID of PostgreSQL's built-in +bytea+ data type.
        BYTEA_COLUMN_TYPE_OID = 17 #:nodoc:

        # Opens a PGconn from the stored connection parameters, records the
        # server's money precision on PostgreSQLColumn, then configures the
        # new connection via #configure_connection.
        def connect
          @connection = PGconn.connect(*@connection_parameters)

          # Money type has a fixed precision of 10 in PostgreSQL 8.2 and below, and as of
          # PostgreSQL 8.3 it has a fixed precision of 19. PostgreSQLColumn.extract_precision
          # should know about this but can't detect it there, so deal with it here.
          precision = postgresql_version >= 80300 ? 19 : 10
          PostgreSQLColumn.money_precision = precision

          configure_connection
        end

        # Configures the encoding, verbosity, schema search path, and time zone of the connection.
        # This is called by #connect and should not be called manually.
        #
        # Reads :encoding, :min_messages and :schema_search_path (falling back to
        # :schema_order) from the adapter config.
        def configure_connection
          encoding = @config[:encoding]
          @connection.set_client_encoding(encoding) if encoding

          min_messages = @config[:min_messages]
          self.client_min_messages = min_messages if min_messages
          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]

          # Use standard-conforming strings if available so we don't have to do the E'...' dance.
          set_standard_conforming_strings

          # With Active Record's UTC default, make the session return TIMESTAMP
          # WITH TIME ZONE values in UTC; otherwise use @local_tz when it is set.
          session_zone =
            if ActiveRecord::Base.default_timezone == :utc
              'UTC'
            else
              @local_tz
            end
          execute("SET time zone '#{session_zone}'", 'SCHEMA') if session_zone
        end

        # Returns the current ID of a table's sequence: <tt>currval</tt> for
        # +sequence_name+ (passed as a bind parameter), converted with Integer().
        def last_insert_id(sequence_name) #:nodoc:
          result = exec_query("SELECT currval($1)", 'SQL', [[nil, sequence_name]])
          value = result.rows.first.first
          Integer(value)
        end

        # Executes a SELECT query and returns the results, performing any data type
        # conversions that are required to be performed here instead of in PostgreSQLColumn.
        # The return value is whatever <tt>exec_query(...).to_a</tt> produces.
        def select(sql, name = nil, binds = [])
          result = exec_query(sql, name, binds)
          result.to_a
        end

        # Executes +sql+ and returns <tt>[fields, rows]</tt>: the result's field
        # names and its rows as produced by #result_as_array.
        #
        # The driver result is cleared in an +ensure+, so it is released even
        # when converting the rows raises.
        def select_raw(sql, name = nil)
          res = execute(sql, name)
          begin
            results = result_as_array(res)
            fields = res.fields
            [fields, results]
          ensure
            res.clear
          end
        end

        # Returns the list of a table's column names, data types, and default values.
        #
        # Each row is <tt>[attname, formatted_type, default_source, not_null]</tt>,
        # read from +pg_attribute+ left-joined to +pg_attrdef+. Only columns with a
        # positive attnum that are not dropped are included, ordered by attnum.
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - ::regclass casts the quoted table name to the table's OID
        def column_definitions(table_name) #:nodoc:
          sql = <<-end_sql
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), d.adsrc, a.attnotnull
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{quote_table_name(table_name)}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
          exec_query(sql, 'SCHEMA').rows
        end

        # Splits the leading identifier off a possibly schema-qualified +name+.
        #
        # Returns <tt>[identifier, rest]</tt>: +identifier+ has surrounding double
        # quotes removed, and +rest+ is what follows the separating dot (+nil+ if
        # nothing remains). Returns +nil+ when no identifier is found.
        #
        #   extract_pg_identifier_from_name('schema.table')      # => ["schema", "table"]
        #   extract_pg_identifier_from_name('"my.schema".table') # => ["my.schema", "table"]
        #   extract_pg_identifier_from_name('table')             # => ["table", nil]
        def extract_pg_identifier_from_name(name)
          match_data = name.start_with?('"') ? name.match(/\"([^\"]+)\"/) : name.match(/([^\.]+)/)
          return nil unless match_data

          # Take the remainder from the end of the match (post_match). Slicing by
          # the match's length is only right when the match starts at index 0,
          # which is not the case for input such as ".table".
          rest = match_data.post_match
          rest = rest[1, rest.length] if rest.start_with?(".")
          [match_data[1], (rest.length > 0 ? rest : nil)]
        end

        # Returns the table reference following INTO in an
        # "INSERT INTO ... VALUES (" statement, with surrounding whitespace
        # stripped, or +nil+ when +sql+ does not match that shape.
        def extract_table_ref_from_insert_sql(sql)
          match = sql.match(/into\s+([^\(]*).*values\s*\(/i)
          match[1].strip if match
        end

        # Builds a new TableDefinition bound to this adapter.
        def table_definition
          definition_class = TableDefinition
          definition_class.new(self)
        end
    end
  end
end