提交 3b50a7a3 编写于 作者: S Sean Griffin

Partially merge #17650

Merges 1d8d5a74. The pull request as a
whole is quite large, and I'm reviewing the smaller pieces individually.
...@@ -19,7 +19,7 @@ def run(records) ...@@ -19,7 +19,7 @@ def run(records)
enums, nodes = nodes.partition { |row| row['typtype'] == 'e' } enums, nodes = nodes.partition { |row| row['typtype'] == 'e' }
domains, nodes = nodes.partition { |row| row['typtype'] == 'd' } domains, nodes = nodes.partition { |row| row['typtype'] == 'd' }
arrays, nodes = nodes.partition { |row| row['typinput'] == 'array_in' } arrays, nodes = nodes.partition { |row| row['typinput'] == 'array_in' }
composites, nodes = nodes.partition { |row| row['typelem'] != '0' } composites, nodes = nodes.partition { |row| row['typelem'].to_i != 0 }
mapped.each { |row| register_mapped_type(row) } mapped.each { |row| register_mapped_type(row) }
enums.each { |row| register_enum_type(row) } enums.each { |row| register_enum_type(row) }
......
...@@ -129,8 +129,8 @@ def indexes(table_name, name = nil) ...@@ -129,8 +129,8 @@ def indexes(table_name, name = nil)
result.map do |row| result.map do |row|
index_name = row[0] index_name = row[0]
unique = row[1] == 't' unique = row[1]
indkey = row[2].split(" ") indkey = row[2].split(" ").map(&:to_i)
inddef = row[3] inddef = row[3]
oid = row[4] oid = row[4]
...@@ -164,7 +164,7 @@ def columns(table_name) ...@@ -164,7 +164,7 @@ def columns(table_name)
type_metadata = fetch_type_metadata(column_name, type, oid, fmod) type_metadata = fetch_type_metadata(column_name, type, oid, fmod)
default_value = extract_value_from_default(default) default_value = extract_value_from_default(default)
default_function = extract_default_function(default_value, default) default_function = extract_default_function(default_value, default)
new_column(column_name, default_value, type_metadata, notnull == 'f', default_function) new_column(column_name, default_value, type_metadata, !notnull, default_function)
end end
end end
......
...@@ -278,8 +278,6 @@ def initialize(connection, logger, connection_parameters, config) ...@@ -278,8 +278,6 @@ def initialize(connection, logger, connection_parameters, config)
@table_alias_length = nil @table_alias_length = nil
connect connect
add_pg_decoders
@statements = StatementPool.new @connection, @statements = StatementPool.new @connection,
self.class.type_cast_config_to_integer(config.fetch(:statement_limit) { 1000 }) self.class.type_cast_config_to_integer(config.fetch(:statement_limit) { 1000 })
...@@ -287,6 +285,8 @@ def initialize(connection, logger, connection_parameters, config) ...@@ -287,6 +285,8 @@ def initialize(connection, logger, connection_parameters, config)
raise "Your version of PostgreSQL (#{postgresql_version}) is too old, please upgrade!" raise "Your version of PostgreSQL (#{postgresql_version}) is too old, please upgrade!"
end end
add_pg_decoders
@type_map = Type::HashLookupTypeMap.new @type_map = Type::HashLookupTypeMap.new
initialize_type_map(type_map) initialize_type_map(type_map)
@local_tz = execute('SHOW TIME ZONE', 'SCHEMA').first["TimeZone"] @local_tz = execute('SHOW TIME ZONE', 'SCHEMA').first["TimeZone"]
...@@ -798,7 +798,7 @@ def can_perform_case_insensitive_comparison_for?(column) ...@@ -798,7 +798,7 @@ def can_perform_case_insensitive_comparison_for?(column)
) )
end_sql end_sql
execute_and_clear(sql, "SCHEMA", []) do |result| execute_and_clear(sql, "SCHEMA", []) do |result|
result.getvalue(0, 0) == 't' result.getvalue(0, 0)
end end
end end
end end
...@@ -814,12 +814,12 @@ def add_pg_decoders ...@@ -814,12 +814,12 @@ def add_pg_decoders
'bool' => PG::TextDecoder::Boolean, 'bool' => PG::TextDecoder::Boolean,
} }
query = <<-SQL query = <<-SQL
SELECT t.oid, t.typname, t.typelem, t.typdelim, t.typinput, t.typtype, t.typbasetype SELECT t.oid, t.typname
FROM pg_type as t FROM pg_type as t
SQL SQL
coders = execute_and_clear(query, "SCHEMA", []) do |result| coders = execute_and_clear(query, "SCHEMA", []) do |result|
result result
.map { |row| construct_coder(row, coders_by_name['typname']) } .map { |row| construct_coder(row, coders_by_name[row['typname']]) }
.compact .compact
end end
...@@ -830,7 +830,7 @@ def add_pg_decoders ...@@ -830,7 +830,7 @@ def add_pg_decoders
def construct_coder(row, coder_class) def construct_coder(row, coder_class)
return unless coder_class return unless coder_class
coder_class.new(oid: row['oid'], name: row['typname']) coder_class.new(oid: row['oid'].to_i, name: row['typname'])
end end
ActiveRecord::Type.add_modifier({ array: true }, OID::Array, adapter: :postgresql) ActiveRecord::Type.add_modifier({ array: true }, OID::Array, adapter: :postgresql)
......
...@@ -31,7 +31,7 @@ def test_values ...@@ -31,7 +31,7 @@ def test_values
assert_equal 123456.789, first.double assert_equal 123456.789, first.double
assert_equal(-::Float::INFINITY, second.single) assert_equal(-::Float::INFINITY, second.single)
assert_equal ::Float::INFINITY, second.double assert_equal ::Float::INFINITY, second.double
assert_same ::Float::NAN, third.double assert_send [third.double, :nan?]
end end
def test_update def test_update
......
...@@ -68,7 +68,7 @@ def test_primary_key_raises_error_if_table_not_found ...@@ -68,7 +68,7 @@ def test_primary_key_raises_error_if_table_not_found
def test_insert_sql_with_proprietary_returning_clause def test_insert_sql_with_proprietary_returning_clause
with_example_table do with_example_table do
id = @connection.insert_sql("insert into ex (number) values(5150)", nil, "number") id = @connection.insert_sql("insert into ex (number) values(5150)", nil, "number")
assert_equal "5150", id assert_equal 5150, id
end end
end end
...@@ -106,21 +106,21 @@ def test_insert_sql_with_returning_disabled ...@@ -106,21 +106,21 @@ def test_insert_sql_with_returning_disabled
connection = connection_without_insert_returning connection = connection_without_insert_returning
id = connection.insert_sql("insert into postgresql_partitioned_table_parent (number) VALUES (1)") id = connection.insert_sql("insert into postgresql_partitioned_table_parent (number) VALUES (1)")
expect = connection.query('select max(id) from postgresql_partitioned_table_parent').first.first expect = connection.query('select max(id) from postgresql_partitioned_table_parent').first.first
assert_equal expect, id assert_equal expect.to_i, id
end end
def test_exec_insert_with_returning_disabled def test_exec_insert_with_returning_disabled
connection = connection_without_insert_returning connection = connection_without_insert_returning
result = connection.exec_insert("insert into postgresql_partitioned_table_parent (number) VALUES (1)", nil, [], 'id', 'postgresql_partitioned_table_parent_id_seq') result = connection.exec_insert("insert into postgresql_partitioned_table_parent (number) VALUES (1)", nil, [], 'id', 'postgresql_partitioned_table_parent_id_seq')
expect = connection.query('select max(id) from postgresql_partitioned_table_parent').first.first expect = connection.query('select max(id) from postgresql_partitioned_table_parent').first.first
assert_equal expect, result.rows.first.first assert_equal expect.to_i, result.rows.first.first
end end
def test_exec_insert_with_returning_disabled_and_no_sequence_name_given def test_exec_insert_with_returning_disabled_and_no_sequence_name_given
connection = connection_without_insert_returning connection = connection_without_insert_returning
result = connection.exec_insert("insert into postgresql_partitioned_table_parent (number) VALUES (1)", nil, [], 'id') result = connection.exec_insert("insert into postgresql_partitioned_table_parent (number) VALUES (1)", nil, [], 'id')
expect = connection.query('select max(id) from postgresql_partitioned_table_parent').first.first expect = connection.query('select max(id) from postgresql_partitioned_table_parent').first.first
assert_equal expect, result.rows.first.first assert_equal expect.to_i, result.rows.first.first
end end
def test_sql_for_insert_with_returning_disabled def test_sql_for_insert_with_returning_disabled
...@@ -238,7 +238,7 @@ def test_exec_insert_number ...@@ -238,7 +238,7 @@ def test_exec_insert_number
result = @connection.exec_query('SELECT number FROM ex WHERE number = 10') result = @connection.exec_query('SELECT number FROM ex WHERE number = 10')
assert_equal 1, result.rows.length assert_equal 1, result.rows.length
assert_equal "10", result.rows.last.last assert_equal 10, result.rows.last.last
end end
end end
...@@ -274,7 +274,7 @@ def test_exec_no_binds ...@@ -274,7 +274,7 @@ def test_exec_no_binds
assert_equal 1, result.rows.length assert_equal 1, result.rows.length
assert_equal 2, result.columns.length assert_equal 2, result.columns.length
assert_equal [['1', 'foo']], result.rows assert_equal [[1, 'foo']], result.rows
end end
end end
...@@ -288,7 +288,7 @@ def test_exec_with_binds ...@@ -288,7 +288,7 @@ def test_exec_with_binds
assert_equal 1, result.rows.length assert_equal 1, result.rows.length
assert_equal 2, result.columns.length assert_equal 2, result.columns.length
assert_equal [['1', 'foo']], result.rows assert_equal [[1, 'foo']], result.rows
end end
end end
...@@ -304,7 +304,7 @@ def test_exec_typecasts_bind_vals ...@@ -304,7 +304,7 @@ def test_exec_typecasts_bind_vals
assert_equal 1, result.rows.length assert_equal 1, result.rows.length
assert_equal 2, result.columns.length assert_equal 2, result.columns.length
assert_equal [['1', 'foo']], result.rows assert_equal [[1, 'foo']], result.rows
end end
end end
......
...@@ -106,6 +106,6 @@ def test_only_catch_active_record_errors_others_bubble_up ...@@ -106,6 +106,6 @@ def test_only_catch_active_record_errors_others_bubble_up
private private
def assert_transaction_is_not_broken def assert_transaction_is_not_broken
assert_equal "1", @connection.select_value("SELECT 1") assert_equal 1, @connection.select_value("SELECT 1")
end end
end end
...@@ -384,16 +384,16 @@ def test_schema_exists? ...@@ -384,16 +384,16 @@ def test_schema_exists?
def test_reset_pk_sequence def test_reset_pk_sequence
sequence_name = "#{SCHEMA_NAME}.#{UNMATCHED_SEQUENCE_NAME}" sequence_name = "#{SCHEMA_NAME}.#{UNMATCHED_SEQUENCE_NAME}"
@connection.execute "SELECT setval('#{sequence_name}', 123)" @connection.execute "SELECT setval('#{sequence_name}', 123)"
assert_equal "124", @connection.select_value("SELECT nextval('#{sequence_name}')") assert_equal 124, @connection.select_value("SELECT nextval('#{sequence_name}')")
@connection.reset_pk_sequence!("#{SCHEMA_NAME}.#{UNMATCHED_PK_TABLE_NAME}") @connection.reset_pk_sequence!("#{SCHEMA_NAME}.#{UNMATCHED_PK_TABLE_NAME}")
assert_equal "1", @connection.select_value("SELECT nextval('#{sequence_name}')") assert_equal 1, @connection.select_value("SELECT nextval('#{sequence_name}')")
end end
def test_set_pk_sequence def test_set_pk_sequence
table_name = "#{SCHEMA_NAME}.#{PK_TABLE_NAME}" table_name = "#{SCHEMA_NAME}.#{PK_TABLE_NAME}"
_, sequence_name = @connection.pk_and_sequence_for table_name _, sequence_name = @connection.pk_and_sequence_for table_name
@connection.set_pk_sequence! table_name, 123 @connection.set_pk_sequence! table_name, 123
assert_equal "124", @connection.select_value("SELECT nextval('#{sequence_name}')") assert_equal 124, @connection.select_value("SELECT nextval('#{sequence_name}')")
@connection.reset_pk_sequence! table_name @connection.reset_pk_sequence! table_name
end end
......
...@@ -184,7 +184,7 @@ def test_cache_does_not_wrap_string_results_in_arrays ...@@ -184,7 +184,7 @@ def test_cache_does_not_wrap_string_results_in_arrays
# Oracle adapter returns count() as Fixnum or Float # Oracle adapter returns count() as Fixnum or Float
if current_adapter?(:OracleAdapter) if current_adapter?(:OracleAdapter)
assert_kind_of Numeric, Task.connection.select_value("SELECT count(*) AS count_all FROM tasks") assert_kind_of Numeric, Task.connection.select_value("SELECT count(*) AS count_all FROM tasks")
elsif current_adapter?(:SQLite3Adapter, :Mysql2Adapter) elsif current_adapter?(:SQLite3Adapter, :Mysql2Adapter, :PostgreSQLAdapter)
# Future versions of the sqlite3 adapter will return numeric # Future versions of the sqlite3 adapter will return numeric
assert_instance_of Fixnum, assert_instance_of Fixnum,
Task.connection.select_value("SELECT count(*) AS count_all FROM tasks") Task.connection.select_value("SELECT count(*) AS count_all FROM tasks")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册