postgresql_adapter.rb 7.3 KB
Newer Older
D
Initial  
David Heinemeier Hansson 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26

# postgresql_adapter.rb
# author: Luke Holden <lholden@cablelan.net>
# notes: Currently this adapter does not pass the test_zero_date_fields
#        and test_zero_datetime_fields unit tests in the BasicsTest test
#        group.
#
#        This is because PostgreSQL cannot store a totally zero
#        timestamp. Instead, null/nil should be used to represent
#        no value.
#

require 'active_record/connection_adapters/abstract_adapter'
require 'parsedate'

module ActiveRecord
  class Base
    # Establishes a connection to the database that's used by all Active Record objects.
    #
    # Reads :host, :port, :username, :password, :database and :schema_order
    # from +config+ (after passing it through symbolize_strings_in_hash).
    # Returns the new PostgreSQLAdapter. Raises ArgumentError when +config+
    # has no :database key.
    def self.postgresql_connection(config) # :nodoc:
      # Load the driver only if the PGconn constant is not already available.
      # Inside a class method self.class is Class, so ask the top-level
      # namespace explicitly instead.
      require_library_or_gem 'postgres' unless Object.const_defined?(:PGconn)
      symbolize_strings_in_hash(config)

      unless config.has_key?(:database)
        raise ArgumentError, "No database specified. Missing argument: database."
      end

      host = config[:host]
      # The port is only filled in when a host is given; without a host it
      # stays nil, whatever :port says.
      port         = host.nil? ? nil : (config[:port] || 5432)
      username     = config[:username].to_s
      password     = config[:password].to_s
      database     = config[:database]
      schema_order = config[:schema_order]

      # The two empty strings fill the positional slots between port and
      # database (presumably options and tty -- confirm against the driver).
      pga = ConnectionAdapters::PostgreSQLAdapter.new(
        PGconn.connect(host, port, "", "", database, username, password), logger
      )

      pga.execute("SET search_path TO #{schema_order}") if schema_order

      pga
    end
  end

  module ConnectionAdapters
46 47 48 49 50 51 52 53 54 55 56 57
    # The PostgreSQL adapter works both with the C-based (http://www.postgresql.jp/interfaces/ruby/) and the Ruby-based
    # (available both as gem and from http://rubyforge.org/frs/?group_id=234&release_id=1145) drivers.
    #
    # Options:
    #
    # * <tt>:host</tt> -- Defaults to localhost
    # * <tt>:port</tt> -- Defaults to 5432
    # * <tt>:username</tt> -- Defaults to nothing
    # * <tt>:password</tt> -- Defaults to nothing
    # * <tt>:database</tt> -- The name of the database. No default, must be provided.
    # * <tt>:schema_order</tt> -- An optional schema order string that is used in a SET search_path TO <schema_order> call on connection.
    class PostgreSQLAdapter < AbstractAdapter
D
Initial  
David Heinemeier Hansson 已提交
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
      # Runs +sql+ through the private select helper and returns whatever it
      # produces (one hash per row). +name+ is forwarded as the log label.
      def select_all(sql, name = nil)
        rows = select(sql, name)
        rows
      end

      # Runs +sql+ and returns only the first row, or nil when the select
      # helper itself returned nil.
      def select_one(sql, name = nil)
        rows = select(sql, name)
        rows.first unless rows.nil?
      end

      # Builds one Column per entry returned by table_structure.
      # Each entry is read as [field name, type, default, ...]; Column.new
      # receives them in the order (name, default, type). +name+ is unused.
      def columns(table_name, name = nil)
        table_structure(table_name).map do |field_name, sql_type, default|
          Column.new(field_name, default, sql_type)
        end
      end

      # Executes an INSERT statement and returns the id of the new row.
      #
      # sql::      the INSERT statement; the table name is taken to be its
      #            third whitespace-separated token ("INSERT INTO <table> ...").
      # name::     label forwarded to execute for logging.
      # pk::       primary key column name, used to look up the sequence.
      # id_value:: when given, returned as-is instead of querying the sequence.
      def insert(sql, name = nil, pk = nil, id_value = nil)
        # Forward the caller's label; the previous `name = nil` argument
        # reassigned the local and always logged without a name.
        execute(sql, name)
        table = sql.split(" ", 4)[2]
        id_value || last_insert_id(table, pk)
      end

      # Sends +sql+ to the connection via its query method, wrapped in log so
      # the statement is recorded under +name+. Returns whatever log returns.
      def execute(sql, name = nil)
        log(sql, name, @connection) do |conn|
          conn.query(sql)
        end
      end

84 85 86 87 88 89 90
      # Runs +sql+ through the connection's exec inside log and returns the
      # cmdtuples value reported by the result object.
      def update(sql, name = nil)
        outcome = nil
        log(sql, name, @connection) do |conn|
          outcome = conn.exec(sql)
        end
        outcome.cmdtuples
      end

      alias_method :delete, :update
D
Initial  
David Heinemeier Hansson 已提交
91 92 93 94 95

      # Transaction control: each method sends the matching SQL keyword
      # through execute.
      def begin_db_transaction
        execute("BEGIN")
      end

      def commit_db_transaction
        execute("COMMIT")
      end

      def rollback_db_transaction
        execute("ROLLBACK")
      end

96 97 98 99 100 101 102 103
      # Quotes +value+ for use in SQL. A plain String bound to a :binary
      # column goes through quote_bytea; everything else is delegated to the
      # superclass implementation.
      def quote(value, column = nil)
        binary_string = value.instance_of?(String) && column && column.type == :binary
        return quote_bytea(value) if binary_string

        super
      end

D
Initial  
David Heinemeier Hansson 已提交
104 105 106 107
      # Wraps +name+ in double quotes so it is treated as a delimited
      # identifier. A double quote inside the name is doubled, which is how
      # SQL spells a literal quote within a quoted identifier; previously
      # such a name ended the identifier early.
      def quote_column_name(name)
        %("#{name.to_s.gsub('"', '""')}")
      end

108 109 110 111
      # Human-readable name of this adapter.
      def adapter_name
        "PostgreSQL"
      end

D
Initial  
David Heinemeier Hansson 已提交
112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
      private
        # Returns the current value of the sequence "<table>_<column>_seq" as
        # an integer. A nil +column+ falls back to 'id'.
        def last_insert_id(table, column = "id")
          key_column = column || 'id'
          rows = @connection.exec("SELECT currval('#{table}_#{key_column}_seq')")
          rows[0][0].to_i
        end

        # Runs +sql+ through the connection's exec (inside log) and returns
        # an array with one hash per row, keyed by field name.
        # Values in columns whose type oid is 17 (bytea) are passed through
        # unescape_bytea; all other values are stored as returned.
        def select(sql, name = nil)
          res = nil
          log(sql, name, @connection) { |connection| res = connection.exec(sql) }

          results = res.result
          rows = []
          if results.length > 0
            fields = res.fields
            results.each do |row|
              hashed_row = {}
              row.each_index do |cel_index|
                column = row[cel_index]
                if res.type(cel_index) == 17  # type oid for bytea
                  column = unescape_bytea(column)
                end
                hashed_row[fields[cel_index]] = column
              end
              rows << hashed_row
            end
          end
          rows
        end

141 142 143 144 145
        # Returns the escape_bytea form of +s+ wrapped in single quotes.
        def quote_bytea(s)
          escaped = escape_bytea(s)
          "'#{escaped}'"
        end

        # Escapes a binary string for embedding in a quoted bytea literal.
        # A backslash byte becomes four backslashes; every other byte becomes
        # two backslashes followed by its three-digit octal value.
        # Returns nil when +s+ is nil.
        #
        # Works on bytes via each_byte instead of indexing a matched
        # character: on Ruby 1.9+ `c[0]` is a one-character String, so the
        # old `c[0].to_i` produced 0 for every non-digit.
        def escape_bytea(s)
          return nil if s.nil?

          pieces = []
          s.each_byte do |byte|
            pieces << (byte == 92 ? '\\\\\\\\' : sprintf('\\\\%03o', byte)) # 92 == backslash
          end
          pieces.join
        end

        # Decodes the escaped text form of a bytea value: a backslash plus
        # three octal digits becomes that byte, and a doubled backslash
        # becomes a single backslash. Returns nil when +s+ is nil.
        #
        # Both escapes are handled in one left-to-right pass. Decoding the
        # octal escapes first (as before) misread an escaped backslash
        # followed by three digits as an octal escape. Only octal values up
        # to 377 are accepted so the result always fits in one byte.
        def unescape_bytea(s)
          return nil if s.nil?

          s.gsub(/\\(?:\\|([0-3][0-7][0-7]))/) { $1 ? $1.oct.chr : '\\' }
        end

D
Initial  
David Heinemeier Hansson 已提交
153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190
        # Splits a possibly schema-qualified table name on '.'.
        # Returns [schema, table]. With more than one segment, the first and
        # last segments are used (stripped); otherwise the schema is
        # "public" and the name is returned untouched.
        def split_table_schema(table_name)
          segments = table_name.split('.')
          return ["public", table_name] unless segments.length > 1

          [segments.first.strip, segments.last.strip]
        end

        # Describes the columns of +table_name+ by querying
        # information_schema.columns for the current database and the schema
        # derived from split_table_schema.
        # Returns one [field, type, default, length] array per column, where
        # type comes from type_as_string and default from default_value.
        def table_structure(table_name)
          catalog = @connection.db
          schema_name, table_name = split_table_schema(table_name)

          # Grab a list of all the default values for the columns.
          sql = [
            "SELECT column_name, column_default, character_maximum_length, data_type ",
            "  FROM information_schema.columns ",
            " WHERE table_catalog = '#{catalog}' ",
            "   AND table_schema = '#{schema_name}' ",
            "   AND table_name = '#{table_name}';"
          ].join

          fetched = nil
          log(sql, nil, @connection) { |connection| fetched = connection.query(sql) }

          # Row layout follows the SELECT list: name, default, max length, data type.
          fetched.map do |row|
            [row[0], type_as_string(row[3], row[2]), default_value(row[1]), row[2]]
          end
        end

        # Maps a data_type string (as selected from information_schema) to the
        # type name used by this adapter, appending "(length)" when
        # +field_length+ is not nil. Types without an explicit mapping are
        # passed through unchanged.
        def type_as_string(field_type, field_length)
          type = case field_type
            when 'numeric', 'real', 'money'      then 'float'
            when 'character varying', 'interval' then 'string'
            when 'timestamp without time zone'   then 'datetime'
            when 'timestamp with time zone'      then 'datetime'
            when 'bytea'                         then 'binary'
            else field_type
          end

          size = field_length.nil? ? "" : "(#{field_length})"

          type + size
        end

        # Translates a column_default expression into a plain default value.
        # Returns the literal content for quoted strings and fixed
        # dates/times, "t"/"f" for booleans, the expression itself when it
        # starts with a number, the current time (as a string) for 'now'
        # defaults, and nil for anything else.
        def default_value(value)
          # Char/String type values. Checked before the boolean tests, which
          # match "true"/"false" anywhere in the expression and would
          # otherwise turn a default such as 'untrue'::text into "t".
          return $1 if value =~ /^'(.*)'::(bpchar|text|character varying)$/

          # Boolean types
          return "t" if value =~ /true/i
          return "f" if value =~ /false/i

          # Numeric values
          return value if value =~ /^[0-9]+(\.[0-9]*)?/

          # Date / Time magic values
          return Time.now.to_s if value =~ /^\('now'::text\)::(date|timestamp)/

          # Fixed dates / times
          return $1 if value =~ /^'(.+)'::(date|timestamp)/

          # Anything else is blank, some user type, or some function
          # and we can't know the value of that, so return nil.
          nil
        end
    end
  end
end