# postgresql_adapter.rb (39.8 KB). Blame annotation: initial commit by David Heinemeier Hansson.
require 'active_record/connection_adapters/abstract_adapter'

3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
begin
  require_library_or_gem 'pg'
rescue LoadError => pg_load_error
  # Fall back to the older ruby-postgres driver. Its PGresult uses different
  # method names, so alias the names this adapter calls onto the originals
  # whenever the driver does not already provide them.
  begin
    require_library_or_gem 'postgres'
    class PGresult
      {
        :nfields    => :num_fields,
        :ntuples    => :num_tuples,
        :ftype      => :type,
        :cmd_tuples => :cmdtuples
      }.each do |adapter_name_for, driver_name|
        alias_method adapter_name_for, driver_name unless method_defined?(adapter_name_for)
      end
    end
  rescue LoadError
    # Neither driver could be loaded: report the original failure to load 'pg'.
    raise pg_load_error
  end
end

D
Initial  
David Heinemeier Hansson 已提交
19 20 21 22
module ActiveRecord
  class Base
    # Establishes a connection to the database that's used by all Active Record objects.
    #
    # +config+ keys (String or Symbol): :host, :port (defaults to 5432),
    # :username, :password and the mandatory :database. :username and
    # :password are converted to Strings when present.
    #
    # Raises ArgumentError when +config+ has no :database key.
    def self.postgresql_connection(config) # :nodoc:
      config = config.symbolize_keys
      host     = config[:host]
      port     = config[:port] || 5432
      username = config[:username].to_s if config[:username]
      password = config[:password].to_s if config[:password]

      if config.has_key?(:database)
        database = config[:database]
      else
        raise ArgumentError, "No database specified. Missing argument: database."
      end

      # The postgres drivers don't allow the creation of an unconnected PGconn object,
      # so just pass a nil connection object for the time being. The two nil
      # slots in the parameter list are presumably the driver's options/tty
      # arguments -- confirm against PGconn.connect.
      ConnectionAdapters::PostgreSQLAdapter.new(nil, logger, [host, port, nil, nil, database, username, password], config)
    end
  end
40

41 42 43 44 45 46 47
  module ConnectionAdapters
    # PostgreSQL-specific extensions to column definitions in a table.
    class PostgreSQLColumn < Column #:nodoc:
      # Instantiates a new PostgreSQL column definition in a table.
      #
      # +default+ is the raw default expression reported by the server; it is
      # reduced to a plain value by extract_value_from_default before being
      # passed on to Column.
      def initialize(name, default, sql_type = nil, null = true)
        super(name, self.class.extract_value_from_default(default), sql_type, null)
      end

      private
        # bigint and smallint have a fixed byte size (8 and 2); every other
        # type is left to Column.
        def extract_limit(sql_type)
          case sql_type
          when /^bigint/i;    8
          when /^smallint/i;  2
          else super
          end
        end

        # Extracts the scale from PostgreSQL-specific data types.
        def extract_scale(sql_type)
          # Money type has a fixed scale of 2.
          sql_type =~ /^money/ ? 2 : super
        end

        # Extracts the precision from PostgreSQL-specific data types.
        def extract_precision(sql_type)
          # Actual code is defined dynamically in PostgreSQLAdapter.connect
          # depending on the server specifics
          super
        end

        # Maps PostgreSQL-specific data types to logical Rails types.
        def simplified_type(field_type)
          case field_type
            # Floating-point types
            when /^(?:real|double precision)$/
              :float
            # Monetary types
            when /^money$/
              :decimal
            # Character types
            when /^(?:character varying|bpchar)(?:\(\d+\))?$/
              :string
            # Binary data types
            when /^bytea$/
              :binary
            # Date/time types
            when /^timestamp with(?:out)? time zone$/
              :datetime
            when /^interval$/
              :string
            # Geometric types
            when /^(?:point|line|lseg|box|"?path"?|polygon|circle)$/
              :string
            # Network address types
            when /^(?:cidr|inet|macaddr)$/
              :string
            # Bit strings
            when /^bit(?: varying)?(?:\(\d+\))?$/
              :string
            # XML type
            when /^xml$/
              :string
            # Arrays
            when /^\D+\[\]$/
              :string
            # Object identifier types
            when /^oid$/
              :integer
            # Pass through all types that are not specific to PostgreSQL.
            else
              super
          end
        end

        # Extracts the value from a PostgreSQL column default definition.
        #
        # Returns the literal part of +default+ as a String, true/false for
        # boolean defaults, or nil when the default is an expression (a
        # function call, a user-defined type, ...) whose value cannot be known here.
        def self.extract_value_from_default(default)
          case default
            # Numeric types. The value may arrive wrapped in parentheses, e.g.
            # "(-1)"; the parentheses are matched but kept out of the capture.
            when /\A\(?(-?\d+(\.\d*)?)\)?\z/
              $1
            # Character types
            when /\A'(.*)'::(?:character varying|bpchar|text)\z/m
              $1
            # Character types (8.1 formatting): decode \ooo octal escapes.
            when /\AE'(.*)'::(?:character varying|bpchar|text)\z/m
              $1.gsub(/\\(\d\d\d)/) { $1.oct.chr }
            # Binary data types
            when /\A'(.*)'::bytea\z/m
              $1
            # Date/time types
            when /\A'(.+)'::(?:time(?:stamp)? with(?:out)? time zone|date)\z/
              $1
            when /\A'(.*)'::interval\z/
              $1
            # Boolean type
            when 'true'
              true
            when 'false'
              false
            # Geometric types
            when /\A'(.*)'::(?:point|line|lseg|box|"?path"?|polygon|circle)\z/
              $1
            # Network address types
            when /\A'(.*)'::(?:cidr|inet|macaddr)\z/
              $1
            # Bit string types
            when /\AB'(.*)'::"?bit(?: varying)?"?\z/
              $1
            # XML type
            when /\A'(.*)'::xml\z/m
              $1
            # Arrays
            when /\A'(.*)'::"?\D+"?\[\]\z/
              $1
            # Object identifier types. The pattern has no capture group, so the
            # whole default is the value. (Plain integers are normally already
            # matched by the numeric branch above.)
            when /\A-?\d+\z/
              default
            else
              # Anything else is blank, some user type, or some function
              # and we can't know the value of that, so return nil.
              nil
          end
        end
    end
  end

  module ConnectionAdapters
168 169
    # The PostgreSQL adapter works both with the native C (http://ruby.scripting.ca/postgres/) and the pure
    # Ruby (available both as gem and from http://rubyforge.org/frs/?group_id=234&release_id=1944) drivers.
170 171 172
    #
    # Options:
    #
P
Pratik Naik 已提交
173 174 175 176 177 178 179 180 181
    # * <tt>:host</tt> - Defaults to "localhost".
    # * <tt>:port</tt> - Defaults to 5432.
    # * <tt>:username</tt> - Defaults to nothing.
    # * <tt>:password</tt> - Defaults to nothing.
    # * <tt>:database</tt> - The name of the database. No default, must be provided.
    # * <tt>:schema_search_path</tt> - An optional schema search path for the connection given as a string of comma-separated schema names.  This is backward-compatible with the <tt>:schema_order</tt> option.
    # * <tt>:encoding</tt> - An optional client encoding that is used in a <tt>SET client_encoding TO <encoding></tt> call on the connection.
    # * <tt>:min_messages</tt> - An optional client min messages that is used in a <tt>SET client_min_messages TO <min_messages></tt> call on the connection.
    # * <tt>:allow_concurrency</tt> - If true, use async query methods so Ruby threads don't deadlock; otherwise, use blocking query methods.
182
    class PostgreSQLAdapter < AbstractAdapter
183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199
      # Name returned by #adapter_name.
      ADAPTER_NAME = 'PostgreSQL'.freeze

      # Abstract Rails column type => PostgreSQL column definition, returned by
      # #native_database_types. :name is the SQL type name; :limit (only set for
      # :string) is a default length, presumably applied when a migration gives
      # none -- confirm against type_to_sql. :primary_key is a complete column
      # definition string rather than a hash.
      NATIVE_DATABASE_TYPES = {
        :primary_key => "serial primary key".freeze,
        :string      => { :name => "character varying", :limit => 255 },
        :text        => { :name => "text" },
        :integer     => { :name => "integer" },
        :float       => { :name => "float" },
        :decimal     => { :name => "decimal" },
        :datetime    => { :name => "timestamp" },
        :timestamp   => { :name => "timestamp" },
        :time        => { :name => "time" },
        :date        => { :name => "date" },
        :binary      => { :name => "bytea" },
        :boolean     => { :name => "boolean" }
      }

200
      # Returns 'PostgreSQL' as adapter name for identification purposes.
201
      def adapter_name
        # Identifies this adapter; always the ADAPTER_NAME constant ('PostgreSQL').
        ADAPTER_NAME
      end

205 206
      # Initializes and connects a PostgreSQL adapter.
      def initialize(connection, logger, connection_parameters, config)
        super(connection, logger)
        # Keep the raw driver parameters and config around; reconnect! may call
        # connect again later, which presumably reads them.
        @connection_parameters, @config = connection_parameters, config

        connect
      end

213 214 215
      # Is this connection alive and ready for queries?
      def active?
        # Use the driver's status call when it exists; otherwise probe with a query.
        if @connection.respond_to?(:status)
          @connection.status == PGconn::CONNECTION_OK
        else
          # We're asking the driver, not ActiveRecord, so use @connection.query instead of #query
          @connection.query 'SELECT 1'
          true
        end
      # postgres-pr raises a NoMethodError when querying if no connection is available.
      rescue PGError, NoMethodError
        false
      end

      # Close then reopen the connection.
      def reconnect!
        if @connection.respond_to?(:reset)
          # The driver can reopen the same connection in place; only the
          # session settings need to be applied again.
          @connection.reset
          configure_connection
        else
          disconnect!
          connect
        end
      end
237

238
      # Close the connection.
239 240 241
      def disconnect!
        begin
          @connection.close
        rescue
          # Best effort: a dead or missing connection may raise on close.
          nil
        end
      end
242

243
      def native_database_types #:nodoc:
        # Hands back the shared constant itself, not a copy.
        NATIVE_DATABASE_TYPES
      end
246

247
      # Does PostgreSQL support migrations?
248 249
      def supports_migrations?
        # Schema migrations are always available on PostgreSQL.
        true
      end

252 253 254 255 256 257 258 259 260 261 262
      # Does PostgreSQL support standard conforming strings?
      def supports_standard_conforming_strings?
        # Temporarily set the client message level above error to prevent unintentional
        # error messages in the logs when working on a PostgreSQL database server that
        # does not support standard conforming strings.
        client_min_messages_old = client_min_messages
        self.client_min_messages = 'panic'

        # postgres-pr does not raise an exception when client_min_messages is set higher
        # than error and "SHOW standard_conforming_strings" fails, but returns an empty
        # PGresult instead; indexing into the empty result raises and lands here too.
        # The value returned is the setting itself (e.g. "on"/"off") or false.
        begin
          query('SHOW standard_conforming_strings')[0][0]
        rescue
          false
        end
      ensure
        # Restore the previous level even if the SHOW was aborted by an exception
        # the rescue above does not catch (e.g. Interrupt), so the connection is
        # never left with its messages silenced.
        self.client_min_messages = client_min_messages_old if client_min_messages_old
      end

268
      def supports_insert_with_returning?
        # INSERT ... RETURNING requires PostgreSQL 8.2 (version number 80200).
        postgresql_version >= 80200
      end

272 273 274
      # Reports that schema-changing (DDL) statements may run inside a transaction.
      def supports_ddl_transactions?
        true
      end
275 276 277 278
      
      # Reports that SAVEPOINTs are available; they are issued by create_savepoint,
      # rollback_to_savepoint and release_savepoint.
      def supports_savepoints?
        true
      end
279

280 281
      # Returns the configured supported identifier length supported by PostgreSQL,
      # or report the default of 63 on PostgreSQL 7.x.
282
      def table_alias_length
        return @table_alias_length if @table_alias_length

        # Memoized: 8.0+ servers are asked once; older ones get the fixed 63.
        @table_alias_length =
          if postgresql_version >= 80000
            query('SHOW max_identifier_length')[0][0].to_i
          else
            63
          end
      end
285

286 287
      # QUOTING ==================================================

288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355
      # Escapes binary strings for bytea input to the database.
      def escape_bytea(value)
        # On first use, replace this method with the implementation that fits
        # the loaded driver, then dispatch to the replacement.
        driver_escapes = PGconn.respond_to?(:escape_bytea)
        self.class.instance_eval do
          if driver_escapes
            define_method(:escape_bytea) do |raw|
              PGconn.escape_bytea(raw) if raw
            end
          else
            # No driver support: write every byte as a doubled-backslash octal
            # escape (\\ooo); nil stays nil.
            define_method(:escape_bytea) do |raw|
              if raw
                raw.unpack('C*').map { |byte| sprintf('\\\\%03o', byte) }.join
              end
            end
          end
        end
        escape_bytea(value)
      end

      # Unescapes bytea output from a database to the binary string it represents.
      # NOTE: This is NOT an inverse of escape_bytea! This is only to be used
      #       on escaped binary output from database drive.
      def unescape_bytea(value)
        # On first use, replace this method with the implementation that fits
        # the loaded driver, then dispatch to it.
        # In each case, check if the value actually is escaped PostgreSQL bytea output
        # or an unescaped Active Record attribute that was just written.
        if PGconn.respond_to?(:unescape_bytea)
          self.class.instance_eval do
            define_method(:unescape_bytea) do |raw|
              if raw =~ /\\\d{3}/
                PGconn.unescape_bytea(raw)
              else
                raw
              end
            end
          end
        else
          self.class.instance_eval do
            define_method(:unescape_bytea) do |raw|
              if raw =~ /\\\d{3}/
                # Decode on raw bytes and pack them at the end, so escapes above
                # \177 come back as single bytes. Appending an Integer to a
                # String would instead encode it as a (multibyte) character
                # on Ruby 1.9+.
                bytes = raw.unpack('C*')
                decoded = []
                i = 0
                while i < bytes.size
                  byte = bytes[i]
                  if byte == 92 # backslash
                    if bytes[i + 1] == 92
                      # "\\\\" is an escaped backslash.
                      i += 1
                    else
                      # "\\ooo" is an octal-escaped byte.
                      byte = bytes[i + 1, 3].pack('C*').oct
                      i += 3
                    end
                  end
                  decoded << byte
                  i += 1
                end
                decoded.pack('C*')
              else
                raw
              end
            end
          end
        end
        unescape_bytea(value)
      end

356 357
      # Quotes PostgreSQL-specific data types for SQL input.
      def quote(value, column = nil) #:nodoc:
        if value.kind_of?(String) && column && column.type == :binary
          "#{quoted_string_prefix}'#{escape_bytea(value)}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^xml$/
          "xml '#{quote_string(value)}'"
        elsif value.kind_of?(Numeric) && column && column.sql_type =~ /^money$/
          # Not truly string input, so doesn't require (or allow) escape string syntax.
          "'#{value.to_s}'"
        elsif value.kind_of?(String) && column && column.sql_type =~ /^bit/
          # The value is interpolated without escaping, so it must consist only
          # of bit or hex digits. \A and \z anchor the whole string; ^ and $
          # would match a single line and let "01\n'; ..." through.
          case value
            when /\A[01]*\z/
              "B'#{value}'" # Bit-string notation
            when /\A[0-9A-F]*\z/i
              "X'#{value}'" # Hexadecimal notation
            else
              # Not a bit/hex literal: quote it as an ordinary escaped string
              # and let the server cast or reject it.
              super
          end
        else
          super
        end
      end

377 378 379 380 381 382 383 384 385 386 387 388
      # Quotes strings for use in SQL input in the postgres driver for better performance.
      def quote_string(s) #:nodoc:
        if PGconn.respond_to?(:escape)
          # Replace this method with a direct call to the driver's escape.
          self.class.instance_eval do
            define_method(:quote_string) do |str|
              PGconn.escape(str)
            end
          end
        else
          # There are some incorrectly compiled postgres drivers out there
          # that don't define PGconn.escape. Drop this override so the
          # inherited quote_string handles all later calls.
          self.class.instance_eval do
            remove_method(:quote_string)
          end
        end
        quote_string(s)
      end

395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413
      # Checks the following cases:
      #
      # - table_name
      # - "table.name"
      # - schema_name.table_name
      # - schema_name."table.name"
      # - "schema.name".table_name
      # - "schema.name"."table.name"
      def quote_table_name(name)
        schema, name_part = extract_pg_identifier_from_name(name.to_s)
        # Unqualified name: a single identifier.
        return quote_column_name(schema) unless name_part

        # Qualified name: quote the schema and the table part separately.
        table_name, _remainder = extract_pg_identifier_from_name(name_part)
        [schema, table_name].map { |part| quote_column_name(part) }.join('.')
      end

414 415
      # Quotes column names for use in SQL queries.
      def quote_column_name(name) #:nodoc:
        # Delegate identifier quoting to the driver; symbols are accepted too.
        PGconn.quote_ident(name.to_s)
      end

419 420 421
      # Quote date/time values for use in SQL input. Includes microseconds
      # if the value is a Time responding to usec.
      def quoted_date(value) #:nodoc:
        # Append zero-padded microseconds for time-like values that expose them.
        if value.acts_like?(:time) && value.respond_to?(:usec)
          "#{super}.#{sprintf("%06d", value.usec)}"
        else
          super
        end
      end

429 430
      # REFERENTIAL INTEGRITY ====================================

431 432 433 434 435 436 437
      def supports_disable_referential_integrity?() #:nodoc:
        # Requires server version 8.1 or later. Compare (major, minor) as a pair
        # so that 9.0, 10.x, ... are not rejected because their minor is < 1.
        parts = query("SHOW server_version")[0][0].split('.')
        major, minor = parts[0].to_i, parts[1].to_i
        major > 8 || (major == 8 && minor >= 1)
      rescue
        false
      end

438
      def disable_referential_integrity(&block) #:nodoc:
        # Ask the server once; the same answer governs disabling and re-enabling.
        supported = supports_disable_referential_integrity?
        if supported
          execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} DISABLE TRIGGER ALL" }.join(";"))
        end
        yield
      ensure
        # Re-enable triggers even when the block raised.
        if supported
          execute(tables.collect { |name| "ALTER TABLE #{quote_table_name(name)} ENABLE TRIGGER ALL" }.join(";"))
        end
      end
448 449 450

      # DATABASE STATEMENTS ======================================

451 452 453 454 455 456
      # Executes a SELECT query and returns an array of rows. Each row is an
      # array of field values.
      def select_rows(sql, name = nil)
        # select_raw is defined elsewhere; presumably it returns [fields, rows],
        # so .last keeps only the rows -- confirm against select_raw.
        select_raw(sql, name).last
      end

457
      # Executes an INSERT query and returns the new record's ID
458
      def insert(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil)
        # Extract the table from the insert sql. Yuck. Assumes the statement
        # starts with "INSERT INTO <table> ...".
        table = sql.split(" ", 4)[2].gsub('"', '')

        # Try an insert with 'returning id' if available (PG >= 8.2)
        if supports_insert_with_returning?
          pk, sequence_name = *pk_and_sequence_for(table) unless pk
          if pk
            id = select_value("#{sql} RETURNING #{quote_column_name(pk)}")
            clear_query_cache
            return id
          end
        end

        # Otherwise, insert then grab last_insert_id. The bare `super` passes
        # the current values of pk and sequence_name.
        if insert_id = super
          insert_id
        else
          # If neither pk nor sequence name is given, look them up.
          unless pk || sequence_name
            pk, sequence_name = *pk_and_sequence_for(table)
          end

          # If a pk is given, fallback to default sequence name.
          # Don't fetch last insert id for a table without a pk.
          if pk && (sequence_name ||= default_sequence_name(table, pk))
            last_insert_id(table, sequence_name)
          end
        end
      end

489 490
      # create a 2D array representing the result set
      def result_as_array(res) #:nodoc:
        # BYTEA columns (type OID 17) come back escaped and must be unescaped.
        bytea_columns = (0...res.nfields).map { |j| res.ftype(j) == 17 }

        (0...res.ntuples).map do |i|
          (0...res.nfields).map do |j|
            data = res.getvalue(i, j)
            data = unescape_bytea(data) if bytea_columns[j] && data.is_a?(String)
            data
          end
        end
      end


      # Queries the database and returns the results in an Array-like object
512
      def query(sql, name = nil) #:nodoc:
        log(sql, name) do
          res = @async ? @connection.async_exec(sql) : @connection.exec(sql)
          # Let the block's value flow back through #log (execute relies on the
          # same behavior). An explicit `return` here would jump out of #log
          # mid-yield and skip whatever it does after the block finishes.
          result_as_array(res)
        end
      end

523
      # Executes an SQL statement, returning a PGresult object on success
524 525
      # or raising a PGError exception otherwise.
      def execute(sql, name = nil)
        log(sql, name) do
          # async_exec is used when :allow_concurrency is enabled, so other
          # Ruby threads are not blocked while waiting for the server.
          if @async
            @connection.async_exec(sql)
          else
            @connection.exec(sql)
          end
        end
      end

535
      # Executes an UPDATE query and returns the number of affected tuples.
536
      def update_sql(sql, name = nil)
        # The inherited update_sql yields the driver result; report its row count.
        super.cmd_tuples
      end

540 541
      # Begins a transaction.
      def begin_db_transaction
        execute "BEGIN"
      end

545 546
      # Commits a transaction.
      def commit_db_transaction
        execute "COMMIT"
      end
549

550 551
      # Aborts a transaction.
      def rollback_db_transaction
        execute "ROLLBACK"
      end
554
      
555 556 557
      if defined?(PGconn::PQTRANS_IDLE)
        # The ruby-pg driver supports inspecting the transaction status,
        # while the ruby-postgres driver does not; only define the check when
        # the driver exposes PQTRANS_IDLE.
        def outside_transaction?
          @connection.transaction_status == PGconn::PQTRANS_IDLE
        end
      end
562

J
Jonathan Viney 已提交
563 564 565 566 567 568 569 570
      def create_savepoint
        savepoint = current_savepoint_name
        execute("SAVEPOINT #{savepoint}")
      end

      def rollback_to_savepoint
        savepoint = current_savepoint_name
        execute("ROLLBACK TO SAVEPOINT #{savepoint}")
      end

571
      def release_savepoint
        execute("RELEASE SAVEPOINT #{current_savepoint_name}")
      end
574

575 576
      # SCHEMA STATEMENTS ========================================

577 578 579 580 581
      def recreate_database(name) #:nodoc:
        # Drop, then create; the result of the create step is returned.
        [:drop_database, :create_database].map { |step| send(step, name) }.last
      end

582 583 584
      # Create a new PostgreSQL database.  Options include <tt>:owner</tt>, <tt>:template</tt>,
      # <tt>:encoding</tt>, <tt>:tablespace</tt>, and <tt>:connection_limit</tt> (note that MySQL uses
      # <tt>:charset</tt> while PostgreSQL uses <tt>:encoding</tt>).
585 586 587 588 589 590 591 592 593 594
      #
      # Example:
      #   create_database config[:database], config
      #   create_database 'foo_development', :encoding => 'unicode'
      def create_database(name, options = {})
        options = options.reverse_merge(:encoding => "utf8")

        # Build the option clauses with map/join rather than Enumerable#sum:
        # Ruby's native sum starts from 0 and cannot add Strings. Identifier
        # values are quoted via quote_column_name so embedded quotes are
        # escaped; unknown keys contribute nothing.
        option_string = options.symbolize_keys.map do |key, value|
          case key
          when :owner
            " OWNER = #{quote_column_name(value)}"
          when :template
            " TEMPLATE = #{quote_column_name(value)}"
          when :encoding
            " ENCODING = '#{quote_string(value.to_s)}'"
          when :tablespace
            " TABLESPACE = #{quote_column_name(value)}"
          when :connection_limit
            # Must be an integer; Integer() raises ArgumentError otherwise.
            " CONNECTION LIMIT = #{Integer(value)}"
          else
            ""
          end
        end.join

        execute "CREATE DATABASE #{quote_table_name(name)}#{option_string}"
      end

      # Drops a PostgreSQL database
      #
      # Example:
      #   drop_database 'matt_development'
      def drop_database(name) #:nodoc:
        if postgresql_version >= 80200
          execute "DROP DATABASE IF EXISTS #{quote_table_name(name)}"
        else
          # No IF EXISTS before 8.2: attempt the drop and log a failure as a
          # missing database.
          begin
            execute "DROP DATABASE #{quote_table_name(name)}"
          rescue ActiveRecord::StatementInvalid
            @logger.warn "#{name} database doesn't exist." if @logger
          end
        end
      end


629 630
      # Returns the list of all tables in the schema search path or a specified schema.
      def tables(name = nil)
        # Strip each entry: search paths such as "public, app" carry a blank
        # after the comma, which would otherwise be quoted as ' app'.
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        query(<<-SQL, name).map { |row| row[0] }
          SELECT tablename
            FROM pg_tables
           WHERE schemaname IN (#{schemas})
        SQL
      end

639 640
      # Returns the list of all indexes for a table.
      def indexes(table_name, name = nil)
        # Strip blanks after commas in the search path (see #tables).
        schemas = schema_search_path.split(/,/).map { |p| quote(p.strip) }.join(',')
        # The table name is escaped before being placed in the string literal.
        # NOTE(review): only the first ten index key columns are considered,
        # and column order within an index is not guaranteed by this query.
        result = query(<<-SQL, name)
          SELECT distinct i.relname, d.indisunique, a.attname
            FROM pg_class t, pg_class i, pg_index d, pg_attribute a
          WHERE i.relkind = 'i'
            AND d.indexrelid = i.oid
            AND d.indisprimary = 'f'
            AND t.oid = d.indrelid
            AND t.relname = '#{quote_string(table_name.to_s)}'
            AND i.relnamespace IN (SELECT oid FROM pg_namespace WHERE nspname IN (#{schemas}) )
            AND a.attrelid = t.oid
            AND ( d.indkey[0]=a.attnum OR d.indkey[1]=a.attnum
               OR d.indkey[2]=a.attnum OR d.indkey[3]=a.attnum
               OR d.indkey[4]=a.attnum OR d.indkey[5]=a.attnum
               OR d.indkey[6]=a.attnum OR d.indkey[7]=a.attnum
               OR d.indkey[8]=a.attnum OR d.indkey[9]=a.attnum )
          ORDER BY i.relname
        SQL

        current_index = nil
        indexes = []

        # One row per indexed column, ordered by index name: start a new
        # IndexDefinition whenever the name changes.
        result.each do |row|
          if current_index != row[0]
            indexes << IndexDefinition.new(table_name, row[0], row[1] == "t", [])
            current_index = row[0]
          end

          indexes.last.columns << row[2]
        end

        indexes
      end

675 676
      # Returns the list of all column definitions for a table.
      def columns(table_name, name = nil)
        # Limit, precision, and scale are all handled by the superclass.
        # The block parameter is named column_name so it does not shadow the
        # method's own +name+ argument.
        column_definitions(table_name).collect do |column_name, type, default, notnull|
          PostgreSQLColumn.new(column_name, default, type, notnull == 'f')
        end
      end

683 684 685 686 687 688 689 690 691 692 693 694 695
      # Returns the current database name.
      def current_database
        rows = query('select current_database()')
        rows.first.first
      end

      # Returns the current database encoding format.
      def encoding
        # Match with = against the server's current_database() so that "_" and
        # "%" in the database name are not treated as LIKE wildcards and the
        # name never has to be interpolated into the SQL.
        query(<<-end_sql)[0][0]
          SELECT pg_encoding_to_char(pg_database.encoding) FROM pg_database
          WHERE pg_database.datname = current_database()
        end_sql
      end

696 697 698 699 700 701
      # Sets the schema search path to a string of comma-separated schema names.
      # Names beginning with $ have to be quoted (e.g. $user => '$user').
      # See: http://www.postgresql.org/docs/current/static/ddl-schemas.html
      #
      # This should be not be called manually but set in database.yml.
      def schema_search_path=(schema_csv)
        # nil leaves the current path untouched. The CSV is interpolated
        # verbatim, so it must come from trusted configuration.
        if schema_csv
          execute "SET search_path TO #{schema_csv}"
          @schema_search_path = schema_csv
        end
      end

708 709
      # Returns the active schema search path.
      def schema_search_path
        # Memoized; schema_search_path= keeps the cached value in sync.
        return @schema_search_path if @schema_search_path
        @schema_search_path = query('SHOW search_path')[0][0]
      end
712

713 714 715 716 717 718 719 720 721 722 723 724
      # Returns the current client message level.
      def client_min_messages
        rows = query('SHOW client_min_messages')
        rows.first.first
      end

      # Set the client message level.
      def client_min_messages=(level)
        statement = "SET client_min_messages TO '#{level}'"
        execute(statement)
      end

      # Returns the sequence name for a table's primary key or some other specified key.
      def default_sequence_name(table_name, pk = nil) #:nodoc:
        default_pk, default_seq = pk_and_sequence_for(table_name)
        # Prefer the sequence found in the catalog; otherwise build the
        # <table>_<column>_seq name from the given pk, the detected pk, or "id".
        default_seq || "#{table_name}_#{pk || default_pk || 'id'}_seq"
      end

729 730
      # Resets the sequence of a table's primary key to the maximum value.
      def reset_pk_sequence!(table, pk = nil, sequence = nil) #:nodoc:
        unless pk && sequence
          default_pk, default_sequence = pk_and_sequence_for(table)
          pk ||= default_pk
          sequence ||= default_sequence
        end
        if pk
          if sequence
            # quote_table_name (not quote_column_name) so a schema-qualified
            # sequence such as "app.users_id_seq" is quoted part by part.
            quoted_sequence = quote_table_name(sequence)

            # Set the sequence so the next value follows MAX(pk), or starts at
            # the sequence's min_value when the table is empty.
            select_value <<-end_sql, 'Reset sequence'
              SELECT setval('#{quoted_sequence}', (SELECT COALESCE(MAX(#{quote_column_name pk})+(SELECT increment_by FROM #{quoted_sequence}), (SELECT min_value FROM #{quoted_sequence})) FROM #{quote_table_name(table)}), false)
            end_sql
          else
            @logger.warn "#{table} has primary key #{pk} with no default sequence" if @logger
          end
        end
      end

749 750
      # Returns a table's primary key and belonging sequence.
      def pk_and_sequence_for(table) #:nodoc:
        # First try looking for a sequence with a dependency on the
        # given table's primary key.
        # NOTE(review): pg_namespace is joined without a condition, so rows are
        # duplicated once per namespace; only the first row is used.
        result = query(<<-end_sql, 'PK and serial sequence')[0]
          SELECT attr.attname, seq.relname
          FROM pg_class      seq,
               pg_attribute  attr,
               pg_depend     dep,
               pg_namespace  name,
               pg_constraint cons
          WHERE seq.oid           = dep.objid
            AND seq.relkind       = 'S'
            AND attr.attrelid     = dep.refobjid
            AND attr.attnum       = dep.refobjsubid
            AND attr.attrelid     = cons.conrelid
            AND attr.attnum       = cons.conkey[1]
            AND cons.contype      = 'p'
            AND dep.refobjid      = '#{quote_table_name(table)}'::regclass
        end_sql

        if result.nil? || result.empty?
          # If that fails, try parsing the primary key's default value.
          # Support the 7.x and 8.0 nextval('foo'::text) as well as
          # the 8.1+ nextval('foo'::regclass).
          result = query(<<-end_sql, 'PK and custom sequence')[0]
            SELECT attr.attname,
              CASE
                WHEN split_part(def.adsrc, '''', 2) ~ '.' THEN
                  substr(split_part(def.adsrc, '''', 2),
                         strpos(split_part(def.adsrc, '''', 2), '.')+1)
                ELSE split_part(def.adsrc, '''', 2)
              END
            FROM pg_class       t
            JOIN pg_attribute   attr ON (t.oid = attrelid)
            JOIN pg_attrdef     def  ON (adrelid = attrelid AND adnum = attnum)
            JOIN pg_constraint  cons ON (conrelid = adrelid AND adnum = conkey[1])
            WHERE t.oid = '#{quote_table_name(table)}'::regclass
              AND cons.contype = 'p'
              AND def.adsrc ~* 'nextval'
          end_sql
        end

        # [primary_key, sequence]
        [result.first, result.last]
      rescue
        # Deliberate best effort: any failure (unknown table, no pk, ...) means
        # "no primary key/sequence found".
        nil
      end

798
      # Renames a table.
799
      def rename_table(name, new_name)
        from, to = quote_table_name(name), quote_table_name(new_name)
        execute "ALTER TABLE #{from} RENAME TO #{to}"
      end
802

803 804
      # Adds a new column to the named table.
      # See TableDefinition#column for details of the options you can use.
S
Scott Barron 已提交
805
      def add_column(table_name, column_name, type, options = {})
        default = options[:default]
        notnull = options[:null] == false

        # Add the column.
        execute("ALTER TABLE #{quote_table_name(table_name)} ADD COLUMN #{quote_column_name(column_name)} #{type_to_sql(type, options[:limit], options[:precision], options[:scale])}")

        # The default and the NOT NULL constraint are applied as separate steps;
        # change_column_null also receives the default (presumably to fill
        # existing rows before the constraint is added -- confirm).
        change_column_default(table_name, column_name, default) if options_include_default?(options)
        change_column_null(table_name, column_name, false, default) if notnull
      end
D
Initial  
David Heinemeier Hansson 已提交
815

816 817
      # Changes the column of a table.
      #
      # Normally this is a single ALTER TABLE ... ALTER COLUMN ... TYPE. If that fails on
      # a server older than PostgreSQL 8.0 (version < 80000), the column is rebuilt
      # inside a transaction instead:
      #   1. add a temporary "<column>_ar_tmp" column of the new type,
      #   2. copy the data across with a CAST,
      #   3. drop the old column and rename the temporary one.
      # Any failure during that rebuild rolls the transaction back and is re-raised.
      # Finally the :default and :null options, when present, are applied.
      def change_column(table_name, column_name, type, options = {})
        quoted_table_name = quote_table_name(table_name)
        sql_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])

        begin
          execute "ALTER TABLE #{quoted_table_name} ALTER COLUMN #{quote_column_name(column_name)} TYPE #{sql_type}"
        rescue ActiveRecord::StatementInvalid
          # ALTER COLUMN ... TYPE exists from PostgreSQL 8.0 (80000) onwards, so on those
          # servers the failure is genuine. The old `> 80000` test wrongly sent 8.0.0
          # down the 7.x path.
          raise if postgresql_version >= 80000

          # This is PostgreSQL 7.x, so we have to use a more arcane way of doing it.
          begin
            begin_db_transaction
            tmp_column_name = "#{column_name}_ar_tmp"
            add_column(table_name, tmp_column_name, type, options)
            execute "UPDATE #{quoted_table_name} SET #{quote_column_name(tmp_column_name)} = CAST(#{quote_column_name(column_name)} AS #{sql_type})"
            remove_column(table_name, column_name)
            rename_column(table_name, tmp_column_name, column_name)
            commit_db_transaction
          rescue
            # Undo the partial rebuild, then propagate the error. Previously it was
            # swallowed and the migration appeared to succeed.
            rollback_db_transaction
            raise
          end
        end

        change_column_default(table_name, column_name, options[:default]) if options_include_default?(options)
        change_column_null(table_name, column_name, options[:null], options[:default]) if options.key?(:null)
      end
841

842 843
      # Changes the default value of a table column; +default+ is rendered with quote.
      def change_column_default(table_name, column_name, default)
        table  = quote_table_name(table_name)
        column = quote_column_name(column_name)
        execute "ALTER TABLE #{table} ALTER COLUMN #{column} SET DEFAULT #{quote(default)}"
      end
846

847 848
      # Adds (+null+ falsy) or drops (+null+ truthy) the NOT NULL constraint on a column.
      # When adding it and a non-nil +default+ is given, rows whose value is currently
      # NULL are first updated to +default+ so the constraint can be applied.
      def change_column_null(table_name, column_name, null, default = nil)
        table  = quote_table_name(table_name)
        column = quote_column_name(column_name)

        if !null && !default.nil?
          execute("UPDATE #{table} SET #{column}=#{quote(default)} WHERE #{column} IS NULL")
        end

        action = null ? 'DROP' : 'SET'
        execute("ALTER TABLE #{table} ALTER #{column} #{action} NOT NULL")
      end

854 855
      # Renames +column_name+ to +new_column_name+ within +table_name+.
      def rename_column(table_name, column_name, new_column_name)
        table = quote_table_name(table_name)
        old_column, new_column = quote_column_name(column_name), quote_column_name(new_column_name)
        execute "ALTER TABLE #{table} RENAME COLUMN #{old_column} TO #{new_column}"
      end
858

859 860
      # Drops an index from a table. The index name is derived by passing
      # +table_name+ and +options+ to index_name.
      def remove_index(table_name, options = {})
        name = index_name(table_name, options)
        execute "DROP INDEX #{quote_table_name(name)}"
      end
863

864 865
      # Maps logical Rails types to PostgreSQL-specific data types.
      #
      # Only :integer is handled here, chosen by the byte size given in +limit+:
      # 1-2 => smallint, nil or 3-4 => integer, 5-8 => bigint. Any other limit raises
      # ActiveRecordError. Every other type is delegated to the superclass.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil)
        return super unless type.to_s == 'integer'

        if limit.nil? || (3..4) === limit
          'integer'
        elsif (1..2) === limit
          'smallint'
        elsif (5..8) === limit
          'bigint'
        else
          raise(ActiveRecordError, "No integer type has byte size #{limit}. Use a numeric with precision 0 instead.")
        end
      end
875

876
      # Returns a SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column. So when +order_by+ is present,
      # this builds DISTINCT ON (+columns+) and additionally selects each ORDER BY column
      # (with any ASC/DESC modifier removed) under the positional aliases alias_0, alias_1, ...
      #
      #   distinct("posts.id", "posts.created_at desc")
      #   # => "DISTINCT ON (posts.id) posts.id, posts.created_at AS alias_0"
      def distinct(columns, order_by) #:nodoc:
        return "DISTINCT #{columns}" if order_by.blank?

        aliased_columns = []
        order_by.split(',').each do |term|
          column = term.split.first # first token only: drops ASC/DESC
          next if column.blank?
          aliased_columns << "#{column} AS alias_#{aliased_columns.size}"
        end

        "DISTINCT ON (#{columns}) #{columns}, " + aliased_columns.join(', ')
      end
896
      
897
      # Returns an ORDER BY clause for the passed order option.
      #
      # PostgreSQL does not allow arbitrary ordering when using DISTINCT ON, so we work around this
      # by wrapping the +sql+ string as a sub-select and ordering in that query. Each ORDER BY
      # term is referenced by its positional alias (id_list.alias_0, ...), keeping DESC when
      # the term ends in "desc" (case-insensitive). +sql+ is modified in place and returned;
      # it is returned untouched when options[:order] is blank.
      def add_order_by_for_association_limiting!(sql, options) #:nodoc:
        return sql if options[:order].blank?

        terms = options[:order].split(',').map { |term| term.strip }.reject(&:blank?)
        clauses = []
        terms.each_with_index do |term, index|
          direction = term =~ /\bdesc$/i ? 'DESC' : nil
          clauses << "id_list.alias_#{index} #{direction}"
        end

        sql.replace "SELECT * FROM (#{sql}) AS id_list ORDER BY #{clauses.join(', ')}"
      end
910

911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927
      protected
        # Returns the version of the connected PostgreSQL server as an Integer in the
        # PGconn#server_version format (e.g. 80302 for 8.3.2), memoized after first use.
        #
        # Drivers without PGconn#server_version fall back to parsing SELECT version().
        # That string does not always carry three numbers: pre-release builds report e.g.
        # "PostgreSQL 8.3beta1", and 10+ uses two-part versions ("PostgreSQL 10.4").
        # Missing parts therefore count as 0, and from major 10 on the result is
        # major * 10000 + minor, as server_version reports it. Returns 0 if the version
        # cannot be determined.
        def postgresql_version
          @postgresql_version ||=
            if @connection.respond_to?(:server_version)
              @connection.server_version
            else
              # Mimic PGconn.server_version behavior
              begin
                version_string = query('SELECT version()')[0][0]
                if version_string =~ /PostgreSQL (\d+)(?:\.(\d+))?(?:\.(\d+))?/
                  major, minor, patch = $1.to_i, $2.to_i, $3.to_i
                  if major >= 10
                    (major * 10000) + minor
                  else
                    (major * 10000) + (minor * 100) + patch
                  end
                else
                  0
                end
              rescue
                # Best effort: an unreadable version is treated as "unknown" (0).
                0
              end
            end
        end

D
Initial  
David Heinemeier Hansson 已提交
928
      private
P
Pratik Naik 已提交
929
        # Type OID of PostgreSQL's built-in money type; select_raw compares
        # PGresult#ftype against it to find money columns.
        MONEY_COLUMN_TYPE_OID = 790 #:nodoc:

        # Connects to a PostgreSQL server and sets up the adapter depending on the
        # connected server's characteristics, then hands over to configure_connection.
        def connect
          @connection = PGconn.connect(*@connection_parameters)
          if PGconn.respond_to?(:translate_results=)
            PGconn.translate_results = false
          end

          # async_exec/async_query are only used when concurrency is allowed and the
          # driver provides them (postgres-pr does not).
          @async = @config[:allow_concurrency] && @connection.respond_to?(:async_exec)

          # Switch to escape string syntax (E'...') right away when the server supports it:
          # doing it lazily on the first quoted string could break a transaction already in
          # progress. Servers that don't know standard_conforming_strings don't support the
          # syntax at all, so they keep the inherited quoted_string_prefix. Note that the
          # method is defined on the adapter class, not just on this instance.
          # See: http://www.postgresql.org/docs/current/static/runtime-config-compatible.html
          if supports_standard_conforming_strings?
            self.class.send(:define_method, :quoted_string_prefix) { 'E' }
          end

          # money has a fixed precision of 10 up to PostgreSQL 8.2 and of 19 from 8.3 on.
          # PostgreSQLColumn.extract_precision cannot see the server version, so it is
          # (re)defined here, on the PostgreSQLColumn class, for the server just reached.
          money_precision =
            if postgresql_version >= 80300
              19
            else
              10
            end
          PostgreSQLColumn.module_eval(<<-end_eval)
            def extract_precision(sql_type)  # def extract_precision(sql_type)
              if sql_type =~ /^money$/       #   if sql_type =~ /^money$/
                #{money_precision}           #     19
              else                           #   else
                super                        #     super
              end                            #   end
            end                              # end
          end_eval

          configure_connection
        end

        # Configures the encoding, verbosity, and schema search path of the connection.
        # This is called by #connect and should not be called manually.
        def configure_connection
          if (encoding = @config[:encoding])
            # Prefer the driver's own call; otherwise fall back to a SET statement.
            if @connection.respond_to?(:set_client_encoding)
              @connection.set_client_encoding(encoding)
            else
              execute("SET client_encoding TO '#{encoding}'")
            end
          end

          min_messages = @config[:min_messages]
          self.client_min_messages = min_messages if min_messages
          # :schema_order is only consulted when :schema_search_path is not set.
          self.schema_search_path = @config[:schema_search_path] || @config[:schema_order]
        end

983 984
        # Returns the current value of +sequence_name+, read with PostgreSQL's currval(),
        # converted with Kernel#Integer. +table+ is accepted but not used here.
        def last_insert_id(table, sequence_name) #:nodoc:
          current = select_value("SELECT currval('#{sequence_name}')")
          Integer(current)
        end

988
        # Executes a SELECT query and returns the results as an Array with one Hash per
        # row, mapping each field name to that row's value. Values come from select_raw,
        # which already applies the conversions that must happen here rather than in
        # PostgreSQLColumn (money columns).
        def select(sql, name = nil)
          fields, rows = select_raw(sql, name)
          rows.map do |row|
            row_hash = {}
            fields.each_with_index { |field, index| row_hash[field] = row[index] }
            row_hash
          end
        end

        # Executes +sql+ and returns two values: the result's field names and its rows,
        # each row being an Array of values as produced by result_as_array. Both are
        # empty when the query returns no rows. The PGresult is always cleared, even
        # when post-processing raises.
        #
        # Money values have their currency symbols and grouping separators stripped.
        # Doing this in PostgreSQLColumn.string_to_decimal would be prettier, but would
        # break form input fields that call value_before_type_cast.
        def select_raw(sql, name = nil)
          res = execute(sql, name)
          begin
            results = result_as_array(res)
            fields = []
            rows = []

            if res.ntuples > 0
              fields = res.fields
              # Column types are the same for every row, so look up the money columns once.
              money_columns = (0...fields.size).select { |i| res.ftype(i) == MONEY_COLUMN_TYPE_OID }

              results.each do |row|
                money_columns.each do |i|
                  # Money output is formatted according to the locale, so there are two
                  # cases to consider (note the decimal separators):
                  #  (1) $12,345,678.12
                  #  (2) $12.345.678,12
                  # Values matching neither (including NULLs) are left unchanged.
                  case value = row[i]
                    when /^-?\D+[\d,]+\.\d{2}$/  # (1)
                      row[i] = value.gsub(/[^-\d\.]/, '')
                    when /^-?\D+[\d\.]+,\d{2}$/  # (2)
                      row[i] = value.gsub(/[^-\d,]/, '').sub(/,/, '.')
                  end
                end
                rows << row
              end
            end

            return fields, rows
          ensure
            res.clear
          end
        end

1039
        # Returns the list of a table's column names, data types, default values and
        # NOT NULL flags, one row per column, in column order.
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - ::regclass is a function that gives the id for a table name
        #  - pg_get_expr(adbin, adrelid) decompiles the stored default expression.
        #    It replaces pg_attrdef.adsrc, which does not track later changes (such as
        #    renamed sequences) and was removed in PostgreSQL 12.
        def column_definitions(table_name) #:nodoc:
          query <<-end_sql
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), pg_get_expr(d.adbin, d.adrelid), a.attnotnull
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{quote_table_name(table_name)}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
        end
1067 1068 1069 1070 1071 1072 1073 1074 1075 1076

        # Splits +name+ into its first identifier and whatever follows it.
        #
        # If +name+ starts with a double quote, the first identifier is the text between
        # the first pair of quotes (escaped "" quotes are not handled). Otherwise it is
        # the text up to the first '.'. Returns [identifier, rest]. A single leading '.'
        # is removed from rest, which is returned as-is (possibly still quoted) or as nil
        # when empty. Returns nil when no identifier can be matched at all.
        #
        #   extract_pg_identifier_from_name('public.posts')  # => ["public", "posts"]
        def extract_pg_identifier_from_name(name)
          pattern = name[0, 1] == '"' ? /\"([^\"]+)\"/ : /([^\.]+)/
          md = pattern.match(name)
          return unless md

          remainder = name[md[0].length..-1]
          remainder = remainder[1..-1] if remainder[0, 1] == "."
          [md[1], remainder.empty? ? nil : remainder]
        end
D
Initial  
David Heinemeier Hansson 已提交
1077 1078 1079
    end
  end
end