提交 192b7bcf 编写于 作者: G Gannon McGibbon

Redact SQL in errors

Move `ActiveRecord::StatementInvalid` SQL to error property.
Also add bindings as an error property.
上级 47ab6b36
* Move `ActiveRecord::StatementInvalid` SQL to error property and include binds as separate error property.
`ActiveRecord::ConnectionAdapters::AbstractAdapter#translate_exception_class` now requires `binds` to be passed as the last argument.
`ActiveRecord::ConnectionAdapters::AbstractAdapter#translate_exception` now requires `message`, `sql`, and `binds` to be passed as keyword arguments.
Subclasses of `ActiveRecord::StatementInvalid` must now provide `sql:` and `binds:` arguments to `super`.
Example:
```
class MySubclassedError < ActiveRecord::StatementInvalid
def initialize(message, sql:, binds:)
super(message, sql: sql, binds: binds)
end
end
```
*Gannon McGibbon*
* Add an `:if_not_exists` option to `create_table`. * Add an `:if_not_exists` option to `create_table`.
Example: Example:
......
...@@ -580,14 +580,12 @@ def extract_limit(sql_type) ...@@ -580,14 +580,12 @@ def extract_limit(sql_type)
$1.to_i if sql_type =~ /\((.*)\)/ $1.to_i if sql_type =~ /\((.*)\)/
end end
def translate_exception_class(e, sql) def translate_exception_class(e, sql, binds)
begin message = "#{e.class.name}: #{e.message}"
message = "#{e.class.name}: #{e.message}: #{sql}"
rescue Encoding::CompatibilityError
message = "#{e.class.name}: #{e.message.force_encoding sql.encoding}: #{sql}"
end
exception = translate_exception(e, message) exception = translate_exception(
e, message: message, sql: sql, binds: binds
)
exception.set_backtrace e.backtrace exception.set_backtrace e.backtrace
exception exception
end end
...@@ -606,18 +604,18 @@ def log(sql, name = "SQL", binds = [], type_casted_binds = [], statement_name = ...@@ -606,18 +604,18 @@ def log(sql, name = "SQL", binds = [], type_casted_binds = [], statement_name =
yield yield
end end
rescue => e rescue => e
raise translate_exception_class(e, sql) raise translate_exception_class(e, sql, binds)
end end
end end
end end
def translate_exception(exception, message) def translate_exception(exception, message:, sql:, binds:)
# override in derived class # override in derived class
case exception case exception
when RuntimeError when RuntimeError
exception exception
else else
ActiveRecord::StatementInvalid.new(message) ActiveRecord::StatementInvalid.new(message, sql: sql, binds: binds)
end end
end end
......
...@@ -642,34 +642,34 @@ def extract_precision(sql_type) ...@@ -642,34 +642,34 @@ def extract_precision(sql_type)
ER_QUERY_INTERRUPTED = 1317 ER_QUERY_INTERRUPTED = 1317
ER_QUERY_TIMEOUT = 3024 ER_QUERY_TIMEOUT = 3024
def translate_exception(exception, message) def translate_exception(exception, message:, sql:, binds:)
case error_number(exception) case error_number(exception)
when ER_DUP_ENTRY when ER_DUP_ENTRY
RecordNotUnique.new(message) RecordNotUnique.new(message, sql: sql, binds: binds)
when ER_NO_REFERENCED_ROW, ER_ROW_IS_REFERENCED, ER_ROW_IS_REFERENCED_2, ER_NO_REFERENCED_ROW_2 when ER_NO_REFERENCED_ROW, ER_ROW_IS_REFERENCED, ER_ROW_IS_REFERENCED_2, ER_NO_REFERENCED_ROW_2
InvalidForeignKey.new(message) InvalidForeignKey.new(message, sql: sql, binds: binds)
when ER_CANNOT_ADD_FOREIGN when ER_CANNOT_ADD_FOREIGN
mismatched_foreign_key(message) mismatched_foreign_key(message, sql: sql, binds: binds)
when ER_CANNOT_CREATE_TABLE when ER_CANNOT_CREATE_TABLE
if message.include?("errno: 150") if message.include?("errno: 150")
mismatched_foreign_key(message) mismatched_foreign_key(message, sql: sql, binds: binds)
else else
super super
end end
when ER_DATA_TOO_LONG when ER_DATA_TOO_LONG
ValueTooLong.new(message) ValueTooLong.new(message, sql: sql, binds: binds)
when ER_OUT_OF_RANGE when ER_OUT_OF_RANGE
RangeError.new(message) RangeError.new(message, sql: sql, binds: binds)
when ER_NOT_NULL_VIOLATION, ER_DO_NOT_HAVE_DEFAULT when ER_NOT_NULL_VIOLATION, ER_DO_NOT_HAVE_DEFAULT
NotNullViolation.new(message) NotNullViolation.new(message, sql: sql, binds: binds)
when ER_LOCK_DEADLOCK when ER_LOCK_DEADLOCK
Deadlocked.new(message) Deadlocked.new(message, sql: sql, binds: binds)
when ER_LOCK_WAIT_TIMEOUT when ER_LOCK_WAIT_TIMEOUT
LockWaitTimeout.new(message) LockWaitTimeout.new(message, sql: sql, binds: binds)
when ER_QUERY_TIMEOUT when ER_QUERY_TIMEOUT
StatementTimeout.new(message) StatementTimeout.new(message, sql: sql, binds: binds)
when ER_QUERY_INTERRUPTED when ER_QUERY_INTERRUPTED
QueryCanceled.new(message) QueryCanceled.new(message, sql: sql, binds: binds)
else else
super super
end end
...@@ -800,11 +800,13 @@ def arel_visitor ...@@ -800,11 +800,13 @@ def arel_visitor
Arel::Visitors::MySQL.new(self) Arel::Visitors::MySQL.new(self)
end end
def mismatched_foreign_key(message) def mismatched_foreign_key(message, sql:, binds:)
parts = message.scan(/`(\w+)`[ $)]/).flatten parts = sql.scan(/`(\w+)`[ $)]/).flatten
MismatchedForeignKey.new( MismatchedForeignKey.new(
self, self,
message: message, message: message,
sql: sql,
binds: binds,
table: parts[0], table: parts[0],
foreign_key: parts[1], foreign_key: parts[1],
target_table: parts[2], target_table: parts[2],
......
...@@ -441,28 +441,28 @@ def check_version ...@@ -441,28 +441,28 @@ def check_version
LOCK_NOT_AVAILABLE = "55P03" LOCK_NOT_AVAILABLE = "55P03"
QUERY_CANCELED = "57014" QUERY_CANCELED = "57014"
def translate_exception(exception, message) def translate_exception(exception, message:, sql:, binds:)
return exception unless exception.respond_to?(:result) return exception unless exception.respond_to?(:result)
case exception.result.try(:error_field, PG::PG_DIAG_SQLSTATE) case exception.result.try(:error_field, PG::PG_DIAG_SQLSTATE)
when UNIQUE_VIOLATION when UNIQUE_VIOLATION
RecordNotUnique.new(message) RecordNotUnique.new(message, sql: sql, binds: binds)
when FOREIGN_KEY_VIOLATION when FOREIGN_KEY_VIOLATION
InvalidForeignKey.new(message) InvalidForeignKey.new(message, sql: sql, binds: binds)
when VALUE_LIMIT_VIOLATION when VALUE_LIMIT_VIOLATION
ValueTooLong.new(message) ValueTooLong.new(message, sql: sql, binds: binds)
when NUMERIC_VALUE_OUT_OF_RANGE when NUMERIC_VALUE_OUT_OF_RANGE
RangeError.new(message) RangeError.new(message, sql: sql, binds: binds)
when NOT_NULL_VIOLATION when NOT_NULL_VIOLATION
NotNullViolation.new(message) NotNullViolation.new(message, sql: sql, binds: binds)
when SERIALIZATION_FAILURE when SERIALIZATION_FAILURE
SerializationFailure.new(message) SerializationFailure.new(message, sql: sql, binds: binds)
when DEADLOCK_DETECTED when DEADLOCK_DETECTED
Deadlocked.new(message) Deadlocked.new(message, sql: sql, binds: binds)
when LOCK_NOT_AVAILABLE when LOCK_NOT_AVAILABLE
LockWaitTimeout.new(message) LockWaitTimeout.new(message, sql: sql, binds: binds)
when QUERY_CANCELED when QUERY_CANCELED
QueryCanceled.new(message) QueryCanceled.new(message, sql: sql, binds: binds)
else else
super super
end end
...@@ -642,7 +642,7 @@ def exec_no_cache(sql, name, binds) ...@@ -642,7 +642,7 @@ def exec_no_cache(sql, name, binds)
def exec_cache(sql, name, binds) def exec_cache(sql, name, binds)
materialize_transactions materialize_transactions
stmt_key = prepare_statement(sql) stmt_key = prepare_statement(sql, binds)
type_casted_binds = type_casted_binds(binds) type_casted_binds = type_casted_binds(binds)
log(sql, name, binds, type_casted_binds, stmt_key) do log(sql, name, binds, type_casted_binds, stmt_key) do
...@@ -696,7 +696,7 @@ def sql_key(sql) ...@@ -696,7 +696,7 @@ def sql_key(sql)
# Prepare the statement if it hasn't been prepared, return # Prepare the statement if it hasn't been prepared, return
# the statement key. # the statement key.
def prepare_statement(sql) def prepare_statement(sql, binds)
@lock.synchronize do @lock.synchronize do
sql_key = sql_key(sql) sql_key = sql_key(sql)
unless @statements.key? sql_key unless @statements.key? sql_key
...@@ -704,7 +704,7 @@ def prepare_statement(sql) ...@@ -704,7 +704,7 @@ def prepare_statement(sql)
begin begin
@connection.prepare nextkey, sql @connection.prepare nextkey, sql
rescue => e rescue => e
raise translate_exception_class(e, sql) raise translate_exception_class(e, sql, binds)
end end
# Clear the queue # Clear the queue
@connection.get_last_result @connection.get_last_result
......
...@@ -529,18 +529,18 @@ def sqlite_version ...@@ -529,18 +529,18 @@ def sqlite_version
@sqlite_version ||= SQLite3Adapter::Version.new(query_value("SELECT sqlite_version(*)")) @sqlite_version ||= SQLite3Adapter::Version.new(query_value("SELECT sqlite_version(*)"))
end end
def translate_exception(exception, message) def translate_exception(exception, message:, sql:, binds:)
case exception.message case exception.message
# SQLite 3.8.2 returns a newly formatted error message: # SQLite 3.8.2 returns a newly formatted error message:
# UNIQUE constraint failed: *table_name*.*column_name* # UNIQUE constraint failed: *table_name*.*column_name*
# Older versions of SQLite return: # Older versions of SQLite return:
# column *column_name* is not unique # column *column_name* is not unique
when /column(s)? .* (is|are) not unique/, /UNIQUE constraint failed: .*/ when /column(s)? .* (is|are) not unique/, /UNIQUE constraint failed: .*/
RecordNotUnique.new(message) RecordNotUnique.new(message, sql: sql, binds: binds)
when /.* may not be NULL/, /NOT NULL constraint failed: .*/ when /.* may not be NULL/, /NOT NULL constraint failed: .*/
NotNullViolation.new(message) NotNullViolation.new(message, sql: sql, binds: binds)
when /FOREIGN KEY constraint failed/i when /FOREIGN KEY constraint failed/i
InvalidForeignKey.new(message) InvalidForeignKey.new(message, sql: sql, binds: binds)
else else
super super
end end
......
...@@ -97,9 +97,13 @@ def initialize(message = nil, record = nil) ...@@ -97,9 +97,13 @@ def initialize(message = nil, record = nil)
# #
# Wraps the underlying database error as +cause+. # Wraps the underlying database error as +cause+.
class StatementInvalid < ActiveRecordError class StatementInvalid < ActiveRecordError
def initialize(message = nil) def initialize(message = nil, sql: nil, binds: nil)
super(message || $!.try(:message)) super(message || $!.try(:message))
@sql = sql
@binds = binds
end end
attr_reader :sql, :binds
end end
# Defunct wrapper class kept for compatibility. # Defunct wrapper class kept for compatibility.
...@@ -118,7 +122,7 @@ class InvalidForeignKey < WrappedDatabaseException ...@@ -118,7 +122,7 @@ class InvalidForeignKey < WrappedDatabaseException
# Raised when a foreign key constraint cannot be added because the column type does not match the referenced column type. # Raised when a foreign key constraint cannot be added because the column type does not match the referenced column type.
class MismatchedForeignKey < StatementInvalid class MismatchedForeignKey < StatementInvalid
def initialize(adapter = nil, message: nil, table: nil, foreign_key: nil, target_table: nil, primary_key: nil) def initialize(adapter = nil, message: nil, sql: nil, binds: nil, table: nil, foreign_key: nil, target_table: nil, primary_key: nil)
@adapter = adapter @adapter = adapter
if table if table
msg = +<<~EOM msg = +<<~EOM
...@@ -135,7 +139,7 @@ def initialize(adapter = nil, message: nil, table: nil, foreign_key: nil, target ...@@ -135,7 +139,7 @@ def initialize(adapter = nil, message: nil, table: nil, foreign_key: nil, target
if message if message
msg << "\nOriginal message: #{message}" msg << "\nOriginal message: #{message}"
end end
super(msg) super(msg, sql: sql, binds: binds)
end end
private private
......
...@@ -286,18 +286,6 @@ def test_select_methods_passing_a_relation ...@@ -286,18 +286,6 @@ def test_select_methods_passing_a_relation
assert_equal "special_db_type", @connection.type_to_sql(:special_db_type) assert_equal "special_db_type", @connection.type_to_sql(:special_db_type)
end end
unless current_adapter?(:PostgreSQLAdapter)
def test_log_invalid_encoding
error = assert_raises RuntimeError do
@connection.send :log, "SELECT 'ы' FROM DUAL" do
raise (+"ы").force_encoding(Encoding::ASCII_8BIT)
end
end
assert_equal "ы", error.message
end
end
def test_supports_multi_insert_is_deprecated def test_supports_multi_insert_is_deprecated
assert_deprecated { @connection.supports_multi_insert? } assert_deprecated { @connection.supports_multi_insert? }
end end
......
...@@ -218,8 +218,8 @@ def test_count_on_invalid_columns_raises ...@@ -218,8 +218,8 @@ def test_count_on_invalid_columns_raises
Account.select("credit_limit, firm_name").count Account.select("credit_limit, firm_name").count
} }
assert_match %r{accounts}i, e.message assert_match %r{accounts}i, e.sql
assert_match "credit_limit, firm_name", e.message assert_match "credit_limit, firm_name", e.sql
end end
def test_apply_distinct_in_count def test_apply_distinct_in_count
......
# frozen_string_literal: true
require "cases/helper"
require "models/book"
module ActiveRecord
class StatementInvalidTest < ActiveRecord::TestCase
fixtures :books
class MockDatabaseError < StandardError
def result
0
end
def error_number
0
end
end
test "message contains no sql" do
sql = Book.where(author_id: 96, cover: "hard").to_sql
error = assert_raises(ActiveRecord::StatementInvalid) do
Book.connection.send(:log, sql, Book.name) do
raise MockDatabaseError
end
end
assert_not error.message.include?("SELECT")
end
test "statement and binds are set on select" do
sql = Book.where(author_id: 96, cover: "hard").to_sql
binds = [Minitest::Mock.new, Minitest::Mock.new]
error = assert_raises(ActiveRecord::StatementInvalid) do
Book.connection.send(:log, sql, Book.name, binds) do
raise MockDatabaseError
end
end
assert_equal error.sql, sql
assert_equal error.binds, binds
end
end
end
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册