postgresql_adapter.rb 20.9 KB
Newer Older
D
Initial  
David Heinemeier Hansson 已提交
1 2 3 4 5 6 7
require 'active_record/connection_adapters/abstract_adapter'

module ActiveRecord
  class Base
    # Establishes a connection to a PostgreSQL database and wraps it in a
    # ConnectionAdapters::PostgreSQLAdapter.
    #
    # +config+ keys (strings or symbols): :host, :port (default 5432),
    # :username, :password, :database (required), :schema_search_path (or the
    # legacy :schema_order), plus the options documented on PostgreSQLAdapter.
    #
    # Raises ArgumentError when no :database key is given.
    def self.postgresql_connection(config) # :nodoc:
      # Load the C driver only when no PGconn is defined yet (postgres-pr may
      # already provide one). Ask Object directly: Class.const_defined? only
      # finds top-level constants through ancestor lookup, which Ruby 1.8's
      # const_defined? does not perform.
      require_library_or_gem 'postgres' unless Object.const_defined?(:PGconn)

      config   = config.symbolize_keys
      host     = config[:host]
      port     = config[:port] || 5432
      username = config[:username].to_s
      password = config[:password].to_s

      unless config.has_key?(:database)
        raise ArgumentError, "No database specified. Missing argument: database."
      end
      database = config[:database]

      pga = ConnectionAdapters::PostgreSQLAdapter.new(
        PGconn.connect(host, port, "", "", database, username, password), logger, config
      )

      # Not every driver has this switch; when it exists, turn result translation off.
      PGconn.translate_results = false if PGconn.respond_to? :translate_results=

      pga.schema_search_path = config[:schema_search_path] || config[:schema_order]

      pga
    end
  end

  module ConnectionAdapters
    # The PostgreSQL adapter works with both the C-based (http://www.postgresql.jp/interfaces/ruby/) and the Ruby-based
    # (available both as gem and from http://rubyforge.org/frs/?group_id=234&release_id=1145) drivers.
    #
    # Options:
    #
    # * <tt>:host</tt> -- Defaults to localhost
    # * <tt>:port</tt> -- Defaults to 5432
    # * <tt>:username</tt> -- Defaults to nothing
    # * <tt>:password</tt> -- Defaults to nothing
    # * <tt>:database</tt> -- The name of the database. No default, must be provided.
    # * <tt>:schema_search_path</tt> -- An optional schema search path for the connection given as a string of comma-separated schema names.  This is backward-compatible with the :schema_order option.
    # * <tt>:encoding</tt> -- An optional client encoding that is used in a SET client_encoding TO <encoding> call on connection.
    # * <tt>:min_messages</tt> -- An optional client message level that is used in a SET client_min_messages TO <min_messages> call on connection.
    # * <tt>:allow_concurrency</tt> -- If true, use async query methods so Ruby threads don't deadlock; otherwise, use blocking query methods.
    class PostgreSQLAdapter < AbstractAdapter
51 52 53 54
      # Human-readable name of this adapter.
      def adapter_name
        "PostgreSQL"
      end

      # Wraps an already-open PGconn. +config+ is kept for configure_connection;
      # its :allow_concurrency entry selects the driver's async query methods.
      def initialize(connection, logger, config = {})
        super(connection, logger)
        @async  = config[:allow_concurrency]
        @config = config
        configure_connection
      end

      # Is this connection alive and ready for queries?
      #
      # Drivers exposing #status are checked against PGconn::CONNECTION_OK;
      # otherwise a trivial query is issued. A PGError or NoMethodError
      # during the check means the connection is not active.
      def active?
        return @connection.status == PGconn::CONNECTION_OK if @connection.respond_to?(:status)

        @connection.query 'SELECT 1'
        true
      # postgres-pr raises a NoMethodError when querying if no conn is available
      rescue PGError, NoMethodError
        false
      end

      # Close then reopen the connection.
      #
      # Only drivers implementing PGconn#reset are supported; for others this
      # is a no-op returning nil. After a reset the session settings are
      # re-applied via configure_connection.
      def reconnect!
        # TODO: postgres-pr doesn't have PGconn#reset.
        return unless @connection.respond_to?(:reset)

        @connection.reset
        configure_connection
      end

      # Closes the underlying connection. Both postgres and postgres-pr respond
      # to :close; any StandardError raised while closing is ignored (nil).
      def disconnect!
        @connection.close
      rescue StandardError
        nil
      end

      # Maps Active Record's abstract column types to PostgreSQL column
      # definitions. A fresh Hash is built on every call.
      def native_database_types
        single_name_types = {
          :text      => "text",
          :integer   => "integer",
          :float     => "float",
          :decimal   => "decimal",
          :datetime  => "timestamp",
          :timestamp => "timestamp",
          :time      => "time",
          :date      => "date",
          :binary    => "bytea",
          :boolean   => "boolean"
        }
        types = {
          :primary_key => "serial primary key",
          :string      => { :name => "character varying", :limit => 255 }
        }
        single_name_types.each { |abstract, native| types[abstract] = { :name => native } }
        types
      end

      # This adapter supports Active Record migrations.
      def supports_migrations?; true; end

      # Maximum length for generated table aliases; 63 matches PostgreSQL's
      # default identifier length limit (NAMEDATALEN - 1).
      def table_alias_length; 63; end

      # QUOTING ==================================================

      # Quotes +value+ for SQL. A String headed for a binary (bytea) column is
      # escaped with escape_bytea and wrapped in single quotes; everything
      # else goes to the generic implementation.
      def quote(value, column = nil)
        binary = value.kind_of?(String) && column && column.type == :binary
        return super unless binary

        "'#{escape_bytea(value)}'"
      end

      # Wraps a column name in double quotes. Embedded double quotes are not
      # escaped.
      def quote_column_name(name)
        '"' + name.to_s + '"'
      end

      # Formats a date/time value for SQL as "YYYY-MM-DD HH:MM:SS", appending
      # ".uuuuuu" microseconds for values that carry them (e.g. Time).
      # Values without #usec (such as Date or DateTime) are formatted to whole
      # seconds instead of raising NoMethodError.
      def quoted_date(value)
        timestamp = value.strftime("%Y-%m-%d %H:%M:%S")
        return timestamp unless value.respond_to?(:usec)

        "#{timestamp}.#{sprintf("%06d", value.usec)}"
      end

      # DATABASE STATEMENTS ======================================

      # Runs an INSERT and returns the new row's id: +id_value+ when the
      # caller supplied one, otherwise the current value of the sequence
      # (+sequence_name+ or the table's default sequence). The table name is
      # taken as the third whitespace-separated word of +sql+, i.e. it
      # assumes the statement starts with "INSERT INTO table".
      def insert(sql, name = nil, pk = nil, id_value = nil, sequence_name = nil) #:nodoc:
        execute(sql, name)
        return id_value if id_value

        table = sql.split(" ", 4)[2]
        last_insert_id(table, sequence_name || default_sequence_name(table, pk))
      end

      # Runs +sql+ through the driver's query method inside a log block and
      # returns the driver's result; async_query is used when
      # :allow_concurrency was configured.
      def query(sql, name = nil) #:nodoc:
        log(sql, name) { @async ? @connection.async_query(sql) : @connection.query(sql) }
      end

      # Runs +sql+ through the driver's exec method inside a log block and
      # returns the driver's result; async_exec is used when
      # :allow_concurrency was configured.
      def execute(sql, name = nil) #:nodoc:
        log(sql, name) { @async ? @connection.async_exec(sql) : @connection.exec(sql) }
      end

      # Executes +sql+ and returns the affected-row count reported by the
      # driver result's cmdtuples.
      def update(sql, name = nil) #:nodoc:
        result = execute(sql, name)
        result.cmdtuples
      end

      # Opens a transaction block with BEGIN.
      def begin_db_transaction #:nodoc:
        execute("BEGIN")
      end

      # Commits the current transaction with COMMIT.
      def commit_db_transaction #:nodoc:
        execute("COMMIT")
      end

      # Aborts the current transaction with ROLLBACK.
      def rollback_db_transaction #:nodoc:
        execute("ROLLBACK")
      end

      # SCHEMA STATEMENTS ========================================

179
      # Return the list of all tables in the schema search path.
      #
      # Each comma-separated search-path entry is stripped of surrounding
      # whitespace and identifier double quotes before being compared with
      # pg_tables.schemaname. Without that, a path such as
      # <tt>"$user", public</tt> yields <tt>' public'</tt>, which matches no schema.
      def tables(name = nil) #:nodoc:
        schema_names = schema_search_path.split(/,/).map { |schema| schema.strip.sub(/\A"(.*)"\z/, '\1') }
        schemas = schema_names.map { |schema| quote(schema) }.join(',')
        query(<<-SQL, name).map { |row| row[0] }
          SELECT tablename
            FROM pg_tables
           WHERE schemaname IN (#{schemas})
        SQL
      end

      # Returns an IndexDefinition for every non-primary-key index on
      # +table_name+ (only the first 10 key columns of an index are matched).
      #
      # Columns are listed in index-key order. The query's ORDER BY only groups
      # rows by index name, so each column's position is looked up in
      # pg_index.indkey rather than trusting the join order.
      def indexes(table_name, name = nil) #:nodoc:
        result = query(<<-SQL, name)
          SELECT i.relname, d.indisunique, a.attname, a.attnum, d.indkey
            FROM pg_class t, pg_class i, pg_index d, pg_attribute a
           WHERE i.relkind = 'i'
             AND d.indexrelid = i.oid
             AND d.indisprimary = 'f'
             AND t.oid = d.indrelid
             AND t.relname = '#{table_name}'
             AND a.attrelid = t.oid
             AND ( d.indkey[0]=a.attnum OR d.indkey[1]=a.attnum
                OR d.indkey[2]=a.attnum OR d.indkey[3]=a.attnum
                OR d.indkey[4]=a.attnum OR d.indkey[5]=a.attnum
                OR d.indkey[6]=a.attnum OR d.indkey[7]=a.attnum
                OR d.indkey[8]=a.attnum OR d.indkey[9]=a.attnum )
          ORDER BY i.relname
        SQL

        current_index = nil
        key_positions = nil
        definitions = []

        result.each do |row|
          index_name, unique, column_name, attnum, indkey = row
          if current_index != index_name
            definitions << IndexDefinition.new(table_name, index_name, unique == "t", [])
            # Assumes indkey arrives as int2vector text: space-separated attnums, e.g. "3 1".
            key_positions = indkey.to_s.split
            current_index = index_name
          end

          position = key_positions.index(attnum.to_s) || key_positions.size
          definitions.last.columns << [position, column_name]
        end

        # Turn the collected [position, column] pairs into names in key order.
        definitions.each do |index|
          index.columns.replace(index.columns.sort.map { |_, column| column })
        end

        definitions
      end

      # Builds a Column for every attribute of +table_name+ from
      # column_definitions, translating the raw default expression and the
      # PostgreSQL type name. A column allows NULL when attnotnull is "f".
      # Limit, precision and scale are derived by the Column class itself.
      def columns(table_name, name = nil) #:nodoc:
        column_definitions(table_name).map do |field|
          column_name, sql_type, raw_default, notnull = field
          Column.new(column_name, default_value(raw_default), translate_field_type(sql_type), notnull == "f")
        end
      end

      # Set the schema search path to a string of comma-separated schema names.
      # The string is interpolated into SET search_path verbatim; nothing is
      # quoted here, so callers must quote special names such as "$user".
      # See http://www.postgresql.org/docs/8.0/interactive/ddl-schemas.html
      #
      # A nil argument is ignored. Otherwise the cached value is cleared so the
      # next schema_search_path call re-reads it from the server.
      def schema_search_path=(schema_csv) #:nodoc:
        return unless schema_csv

        execute "SET search_path TO #{schema_csv}"
        @schema_search_path = nil
      end

      # Current search path from SHOW search_path (first column of the first
      # row), cached until schema_search_path= clears it.
      def schema_search_path #:nodoc:
        return @schema_search_path if @schema_search_path

        rows = query('SHOW search_path')
        @schema_search_path = rows[0][0]
      end

      # Sequence for +table_name+'s primary key: the one found by
      # pk_and_sequence_for, otherwise "<table>_<key>_seq", where <key> is
      # +pk+, the discovered primary key, or "id" (first available).
      def default_sequence_name(table_name, pk = nil)
        discovered_pk, discovered_seq = pk_and_sequence_for(table_name)
        return discovered_seq if discovered_seq

        "#{table_name}_#{pk || discovered_pk || 'id'}_seq"
      end

      # Resets sequence to the max value of the table's pk if present.
      #
      # +pk+ and +sequence+ default to what pk_and_sequence_for reports. With
      # no primary key nothing happens. A key without a sequence only logs a
      # warning, and only when a logger is configured.
      def reset_pk_sequence!(table, pk = nil, sequence = nil)
        unless pk && sequence
          found_pk, found_sequence = pk_and_sequence_for(table)
          pk ||= found_pk
          sequence ||= found_sequence
        end
        return unless pk

        if sequence
          select_value <<-end_sql, 'Reset sequence'
              SELECT setval('#{sequence}', (SELECT COALESCE(MAX(#{pk})+(SELECT increment_by FROM #{sequence}), (SELECT min_value FROM #{sequence})) FROM #{table}), false)
          end_sql
        elsif @logger
          @logger.warn "#{table} has primary key #{pk} with no default sequence"
        end
      end

      # Find a table's primary key and sequence.
      #
      # Returns [primary_key_column, sequence_name]; the sequence name is not
      # schema-qualified. Returns nil when neither lookup finds a row or any
      # error occurs (the lookup is best-effort).
      def pk_and_sequence_for(table)
        # First try looking for a sequence with a dependency on the
        # given table's primary key.
        result = query(<<-end_sql, 'PK and serial sequence')[0]
          SELECT attr.attname, name.nspname, seq.relname
          FROM pg_class      seq,
               pg_attribute  attr,
               pg_depend     dep,
               pg_namespace  name,
               pg_constraint cons
          WHERE seq.oid           = dep.objid
            AND seq.relnamespace  = name.oid
            AND seq.relkind       = 'S'
            AND attr.attrelid     = dep.refobjid
            AND attr.attnum       = dep.refobjsubid
            AND attr.attrelid     = cons.conrelid
            AND attr.attnum       = cons.conkey[1]
            AND cons.contype      = 'p'
            AND dep.refobjid      = '#{table}'::regclass
        end_sql

        if result.nil? || result.empty?
          # If that fails, try parsing the primary key's default value.
          # Support the 7.x and 8.0 nextval('foo'::text) as well as
          # the 8.1+ nextval('foo'::regclass).
          # The default's text comes from pg_get_expr(adbin, adrelid):
          # pg_attrdef.adsrc was deprecated and removed in PostgreSQL 12.
          # TODO: assumes sequence is in same schema as table.
          result = query(<<-end_sql, 'PK and custom sequence')[0]
            SELECT attr.attname, name.nspname, split_part(pg_get_expr(def.adbin, def.adrelid), '''', 2)
            FROM pg_class       t
            JOIN pg_namespace   name ON (t.relnamespace = name.oid)
            JOIN pg_attribute   attr ON (t.oid = attrelid)
            JOIN pg_attrdef     def  ON (adrelid = attrelid AND adnum = attnum)
            JOIN pg_constraint  cons ON (conrelid = adrelid AND adnum = conkey[1])
            WHERE t.oid = '#{table}'::regclass
              AND cons.contype = 'p'
              AND pg_get_expr(def.adbin, def.adrelid) ~* 'nextval'
          end_sql
        end

        # The middle column (schema name) is deliberately dropped: unqualified
        # sequences cannot be qualified, since Rails does not qualify any table
        # access and relies on the search path instead.
        [result.first, result.last]
      rescue
        nil
      end

      # Renames table +name+ to +new_name+; both are interpolated unquoted.
      def rename_table(name, new_name)
        statement = "ALTER TABLE #{name} RENAME TO #{new_name}"
        execute statement
      end

      # Adds +column_name+ of abstract +type+ to +table_name+.
      #
      # Options:
      # * :limit, :precision, :scale -- passed to type_to_sql, as in change_column.
      # * :default -- set as the column default. When :null is false, existing
      #   NULL values are first backfilled with the quoted default.
      # * :null => false -- adds a NOT NULL constraint.
      #
      # The column name is quoted with quote_column_name in every statement,
      # matching change_column_default and rename_column.
      def add_column(table_name, column_name, type, options = {})
        default = options[:default]
        notnull = options[:null] == false
        quoted_column = quote_column_name(column_name)
        sql_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])

        # Add the column.
        execute("ALTER TABLE #{table_name} ADD COLUMN #{quoted_column} #{sql_type}")

        # Set optional default. If not null, update nulls to the new default.
        unless default.nil?
          change_column_default(table_name, column_name, default)
          if notnull
            execute("UPDATE #{table_name} SET #{quoted_column}=#{quote(default)} WHERE #{quoted_column} IS NULL")
          end
        end

        if notnull
          execute("ALTER TABLE #{table_name} ALTER #{quoted_column} SET NOT NULL")
        end
      end

      # Changes +column_name+ to +type+ (using :limit, :precision and :scale
      # from +options+), then applies options[:default] when it is not nil.
      #
      # ALTER COLUMN ... TYPE is tried first. If it raises StatementInvalid the
      # server is assumed to be PostgreSQL 7.x, and the column is rebuilt
      # inside a transaction: add a temporary column, copy the data with CAST,
      # drop the original and rename the temporary column into place. Any
      # error during that rebuild rolls the transaction back and is re-raised.
      def change_column(table_name, column_name, type, options = {}) #:nodoc:
        sql_type = type_to_sql(type, options[:limit], options[:precision], options[:scale])
        begin
          execute "ALTER TABLE #{table_name} ALTER COLUMN #{column_name} TYPE #{sql_type}"
        rescue ActiveRecord::StatementInvalid
          # This is PG7, so we use a more arcane way of doing it.
          # The temporary column takes only the type options: :default is
          # applied at the end, and :null => false could fail because the new
          # column is still empty when add_column adds NOT NULL.
          tmp_column = "#{column_name}_ar_tmp"
          tmp_options = options.dup
          tmp_options.delete(:default)
          tmp_options.delete(:null)

          begin_db_transaction
          begin
            add_column(table_name, tmp_column, type, tmp_options)
            execute "UPDATE #{table_name} SET #{tmp_column} = CAST(#{column_name} AS #{sql_type})"
            remove_column(table_name, column_name)
            rename_column(table_name, tmp_column, column_name)
            commit_db_transaction
          rescue
            rollback_db_transaction
            raise
          end
        end
        change_column_default(table_name, column_name, options[:default]) unless options[:default].nil?
      end

      # Sets the default of +column_name+ (quoted with quote_column_name) to
      # the SQL-quoted +default+.
      def change_column_default(table_name, column_name, default) #:nodoc:
        quoted_name    = quote_column_name(column_name)
        quoted_default = quote(default)
        execute "ALTER TABLE #{table_name} ALTER COLUMN #{quoted_name} SET DEFAULT #{quoted_default}"
      end

      # Renames a column; both names are quoted with quote_column_name.
      def rename_column(table_name, column_name, new_column_name) #:nodoc:
        from = quote_column_name(column_name)
        to   = quote_column_name(new_column_name)
        execute "ALTER TABLE #{table_name} RENAME COLUMN #{from} TO #{to}"
      end

      # Drops the index whose name index_name derives from +table_name+ and
      # +options+.
      def remove_index(table_name, options) #:nodoc:
        index = index_name(table_name, options)
        execute "DROP INDEX #{index}"
      end

      # Maps an abstract type to SQL. Integers are sized by +limit+
      # (presumably a byte count): nil or 4 gives integer, less than 4 gives
      # smallint, anything larger gives bigint. Every other type is delegated
      # to the generic implementation.
      def type_to_sql(type, limit = nil, precision = nil, scale = nil) #:nodoc:
        return super unless type.to_s == 'integer'
        return 'integer' if limit.nil? || limit == 4

        limit < 4 ? 'smallint' : 'bigint'
      end

      # SELECT DISTINCT clause for a given set of columns and a given ORDER BY clause.
      #
      # PostgreSQL requires the ORDER BY columns in the select list for distinct queries, and
      # requires that the ORDER BY include the distinct column.
      #
      #   distinct("posts.id", "posts.created_at desc")
      #   # => "DISTINCT ON (posts.id) posts.id, posts.created_at AS alias_0"
      #
      # Each ORDER BY column is selected as alias_N (N = its position), the
      # names add_order_by_for_association_limiting! orders the outer query by.
      # When +order_by+ names no columns, a plain DISTINCT clause is returned.
      def distinct(columns, order_by)
        return "DISTINCT #{columns}" if order_by.blank?

        # construct a clean list of column names from the ORDER BY clause, removing
        # any asc/desc modifiers
        order_columns = order_by.split(',').collect { |s| s.split.first }
        order_columns.delete_if(&:blank?)
        # An ORDER BY made only of separators (e.g. " , ") leaves nothing to
        # add; appending an empty list would yield invalid SQL ("... posts.id, ").
        return "DISTINCT #{columns}" if order_columns.empty?

        aliased = []
        order_columns.each_with_index { |column, i| aliased << "#{column} AS alias_#{i}" }

        # return a DISTINCT ON() clause that's distinct on the columns we want but includes
        # all the required columns for the ORDER BY to work properly
        "DISTINCT ON (#{columns}) #{columns}, #{aliased.join(', ')}"
      end

      # ORDER BY clause for the passed order option.
      #
      # PostgreSQL does not allow arbitrary ordering when using DISTINCT ON, so we work around this
      # by wrapping the sql as a sub-select and ordering in that query.
      #
      # The N-th non-blank entry of options[:order] becomes "id_list.alias_N",
      # followed by DESC when the entry ends in "desc". +sql+ is rewritten in
      # place and returned; a blank :order leaves it untouched.
      def add_order_by_for_association_limiting!(sql, options)
        return sql if options[:order].blank?

        entries = options[:order].split(',').map { |entry| entry.strip }.reject(&:blank?)
        clauses = []
        entries.each_with_index do |entry, i|
          direction = entry =~ /\bdesc$/i ? 'DESC' : nil
          clauses << "id_list.alias_#{i} #{direction}"
        end

        sql.replace "SELECT * FROM (#{sql}) AS id_list ORDER BY #{clauses.join(', ')}"
      end

      private
        # PostgreSQL type OIDs that select converts specially. +private+ does
        # not apply to constants; they stay reachable as PostgreSQLAdapter::NAME.
        BYTEA_COLUMN_TYPE_OID   = 17
        TIMESTAMPOID            = 1114
        TIMESTAMPTZOID          = 1184
        NUMERIC_COLUMN_TYPE_OID = 1700

        # Applies the optional :encoding and :min_messages entries of the
        # connection config as SET statements on the current session.
        def configure_connection
          encoding = @config[:encoding]
          execute("SET client_encoding TO '#{encoding}'") if encoding
          min_messages = @config[:min_messages]
          execute("SET client_min_messages TO '#{min_messages}'") if min_messages
        end

        # Returns currval() of +sequence_name+ in this session as an Integer.
        # +table+ is accepted for interface compatibility but not used.
        def last_insert_id(table, sequence_name)
          value = select_value("SELECT currval('#{sequence_name}')")
          Integer(value)
        end

        # Runs +sql+ and returns its rows as an Array of Hashes keyed by field
        # name. Values are converted by column type OID: bytea goes through
        # unescape_bytea, timestamp/timestamptz through cast_to_time, and
        # numeric values through #to_d when they respond to it.
        #
        # The driver result is always cleared, even when a conversion raises.
        def select(sql, name = nil)
          res = execute(sql, name)
          begin
            rows = []
            results = res.result
            if results.length > 0
              fields = res.fields
              results.each do |row|
                hashed_row = {}
                row.each_index do |cel_index|
                  column = row[cel_index]

                  case res.type(cel_index)
                  when BYTEA_COLUMN_TYPE_OID
                    column = unescape_bytea(column)
                  when TIMESTAMPTZOID, TIMESTAMPOID
                    column = cast_to_time(column)
                  when NUMERIC_COLUMN_TYPE_OID
                    column = column.to_d if column.respond_to?(:to_d)
                  end

                  hashed_row[fields[cel_index]] = column
                end
                rows << hashed_row
              end
            end
            rows
          ensure
            res.clear
          end
        end

        # Escapes a binary String for a quoted bytea literal; nil or false
        # input yields nil.
        #
        # The first call picks an implementation and installs it in place of
        # this method via define_method on the adapter class. It uses
        # PGconn.escape_bytea when the driver has it; otherwise a pure-Ruby
        # fallback writes every byte as a doubled-backslash octal escape (\\ooo).
        # The new implementation is then called with +s+.
        def escape_bytea(s)
          implementation =
            if PGconn.respond_to?(:escape_bytea)
              lambda { |str| PGconn.escape_bytea(str) if str }
            else
              lambda do |str|
                if str
                  str.unpack('C*').inject('') { |escaped, byte| escaped << sprintf('\\\\%03o', byte) }
                end
              end
            end
          self.class.send(:define_method, :escape_bytea, &implementation)
          escape_bytea(s)
        end

        # Decodes a bytea value returned by the server into a binary String;
        # nil or false input yields nil.
        #
        # The first call picks an implementation and installs it in place of
        # this method via define_method on the adapter class. It uses
        # PGconn.unescape_bytea when the driver has it, otherwise a pure-Ruby
        # decoder; the new implementation is then called with +s+.
        #
        # The fallback works on bytes and builds its result with Array#pack, so
        # octal escapes >= \200 become single bytes. Appending an Integer to a
        # String would instead encode a multibyte character on Ruby 1.9+.
        # It accepts both output formats:
        # * escape: "\\" is a backslash and "\ooo" is an octal byte;
        # * hex: a leading "\x" followed by two hex digits per byte.
        def unescape_bytea(s)
          if PGconn.respond_to? :unescape_bytea
            self.class.send(:define_method, :unescape_bytea) do |str|
              PGconn.unescape_bytea(str) if str
            end
          else
            self.class.send(:define_method, :unescape_bytea) do |str|
              if str
                if str[0, 2] == "\\x"
                  # Hex format: "\x" followed by two hex digits per byte.
                  [str[2..-1]].pack('H*')
                else
                  backslash = 92 # byte value of "\\"
                  bytes = str.unpack('C*')
                  decoded = []
                  i = 0
                  while i < bytes.size
                    if bytes[i] != backslash
                      decoded << bytes[i]
                      i += 1
                    elsif bytes[i + 1] == backslash
                      decoded << backslash
                      i += 2
                    else
                      decoded << bytes[i + 1, 3].pack('C*').oct
                      i += 4
                    end
                  end
                  decoded.pack('C*')
                end
              end
            end
          end
          unescape_bytea(s)
        end

        # Query a table's column names, types, default expressions and NOT NULL flags.
        #
        # The underlying query is roughly:
        #  SELECT column.name, column.type, default.value, column.not_null
        #    FROM column LEFT JOIN default
        #      ON column.table_id = default.table_id
        #     AND column.num = default.column_num
        #   WHERE column.table_id = get_table_id('table_name')
        #     AND column.num > 0
        #     AND NOT column.is_dropped
        #   ORDER BY column.num
        #
        # If the table name is not prefixed with a schema, the database will
        # take the first match from the schema search path.
        #
        # Query implementation notes:
        #  - format_type includes the column size constraint, e.g. varchar(50)
        #  - ::regclass is a function that gives the id for a table name
        #  - pg_get_expr(adbin, adrelid) renders the default expression; the
        #    pg_attrdef.adsrc column was deprecated and removed in PostgreSQL 12
        def column_definitions(table_name)
          query <<-end_sql
            SELECT a.attname, format_type(a.atttypid, a.atttypmod), pg_get_expr(d.adbin, d.adrelid), a.attnotnull
              FROM pg_attribute a LEFT JOIN pg_attrdef d
                ON a.attrelid = d.adrelid AND a.attnum = d.adnum
             WHERE a.attrelid = '#{table_name}'::regclass
               AND a.attnum > 0 AND NOT a.attisdropped
             ORDER BY a.attnum
          end_sql
        end

        # Translate PostgreSQL-specific types into simplified SQL types.
        # These are special cases; standard types are handled by
        # ConnectionAdapters::Column#simplified_type.
        #
        # Rules are tried in order and the first match wins. Most patterns are
        # anchored at the beginning of +field_type+, since it may have a size
        # constraint on the end. Unmatched types pass through unchanged.
        def translate_field_type(field_type)
          rules = [
            # PostgreSQL array data types.
            [/\[\]$/i, 'string'],
            [/^timestamp/i, 'datetime'],
            [/^real|^money/i, 'float'],
            [/^interval/i, 'string'],
            # geometric types (the line type is currently not implemented in postgresql)
            [/^(?:point|lseg|box|"?path"?|polygon|circle)/i, 'string'],
            [/^bytea/i, 'binary']
          ]
          rule = rules.find { |pattern, _simplified| pattern === field_type }
          rule ? rule.last : field_type # Pass through standard types.
        end

        # Converts the text of a column's default expression into the raw
        # default passed to Column.new, or nil when it is not a literal this
        # adapter recognizes (functions, user types, or no default at all).
        #
        # String-typed literals are matched first, so a string default that
        # merely contains "true" or "false" (e.g. 'untrue'::text) is not
        # mistaken for a boolean. Their SQL-escaped quotes ('') are unescaped
        # and multi-line values are supported.
        def default_value(value)
          # Char/String/Bytea type values
          if value =~ /\A'(.*)'::(bpchar|text|character varying|bytea)\z/m
            return $1.gsub("''", "'")
          end

          # Boolean types
          return "t" if value =~ /true/i
          return "f" if value =~ /false/i

          # Numeric values
          return value if value =~ /^-?[0-9]+(\.[0-9]*)?/

          # Fixed dates / times
          return $1 if value =~ /^'(.+)'::(date|timestamp)/

          # Anything else is blank, some user type, or some function
          # and we can't know the value of that, so return nil.
          nil
        end

        # Only needed for DateTime instances: converts one into a Time built
        # with Base.default_timezone (e.g. Time.utc / Time.local) from its
        # calendar fields; any UTC offset on the DateTime is not applied.
        # Any other value is returned unchanged. Returns nil if building the
        # Time raises.
        def cast_to_time(value)
          return value unless value.class == DateTime
          v = value
          # Stock DateTime has no #usec (calling it raised NoMethodError), so
          # derive microseconds from sec_fraction (a fraction of a second on
          # Ruby 1.9+) unless an extension provides #usec.
          usec = v.respond_to?(:usec) ? v.usec : (v.sec_fraction * 1_000_000).to_i
          time_array = [v.year, v.month, v.day, v.hour, v.min, v.sec, usec]
          Time.send(Base.default_timezone, *time_array) rescue nil
        end
    end
  end
end