Merge pull request #51139 from Shopify/relation-bound-sql-literal
Relation#where build BoundSqlLiteral rather than eagerly interpolate
This commit is contained in:
commit
684131a4f0
@ -21,9 +21,6 @@ def cast_bound_value(value)
|
||||
"1"
|
||||
when false
|
||||
"0"
|
||||
when ActiveSupport::Duration
|
||||
warn_quote_duration_deprecated
|
||||
value.to_s
|
||||
else
|
||||
value
|
||||
end
|
||||
|
@ -141,6 +141,8 @@ def type_cast(value) # :nodoc:
|
||||
encode_array(value)
|
||||
when Range
|
||||
encode_range(value)
|
||||
when Rational
|
||||
value.to_f
|
||||
else
|
||||
super
|
||||
end
|
||||
|
@ -63,7 +63,7 @@ def quote_default_expression(value, column) # :nodoc:
|
||||
|
||||
def type_cast(value) # :nodoc:
|
||||
case value
|
||||
when BigDecimal
|
||||
when BigDecimal, Rational
|
||||
value.to_f
|
||||
when String
|
||||
if value.encoding == Encoding::ASCII_8BIT
|
||||
|
@ -1512,9 +1512,21 @@ def build_subquery(subquery_alias, select_value) # :nodoc:
|
||||
def build_where_clause(opts, rest = []) # :nodoc:
|
||||
opts = sanitize_forbidden_attributes(opts)
|
||||
|
||||
if opts.is_a?(Array)
|
||||
opts, *rest = opts
|
||||
end
|
||||
|
||||
case opts
|
||||
when String, Array
|
||||
parts = [klass.sanitize_sql(rest.empty? ? opts : [opts, *rest])]
|
||||
when String
|
||||
if rest.empty?
|
||||
parts = [Arel.sql(opts)]
|
||||
elsif rest.first.is_a?(Hash) && /:\w+/.match?(opts)
|
||||
parts = [build_named_bound_sql_literal(opts, rest.first)]
|
||||
elsif opts.include?("?")
|
||||
parts = [build_bound_sql_literal(opts, rest)]
|
||||
else
|
||||
parts = [klass.sanitize_sql(rest.empty? ? opts : [opts, *rest])]
|
||||
end
|
||||
when Hash
|
||||
opts = opts.transform_keys do |key|
|
||||
if key.is_a?(Array)
|
||||
@ -1550,6 +1562,46 @@ def async
|
||||
spawn.async!
|
||||
end
|
||||
|
||||
def build_named_bound_sql_literal(statement, values)
|
||||
bound_values = values.transform_values do |value|
|
||||
if ActiveRecord::Relation === value
|
||||
Arel.sql(value.to_sql)
|
||||
elsif value.respond_to?(:map) && !value.acts_like?(:string)
|
||||
values = value.map { |v| v.respond_to?(:id_for_database) ? v.id_for_database : v }
|
||||
values.empty? ? nil : values
|
||||
else
|
||||
value = value.id_for_database if value.respond_to?(:id_for_database)
|
||||
value
|
||||
end
|
||||
end
|
||||
|
||||
begin
|
||||
Arel::Nodes::BoundSqlLiteral.new("(#{statement})", nil, bound_values)
|
||||
rescue Arel::BindError => error
|
||||
raise ActiveRecord::PreparedStatementInvalid, error.message
|
||||
end
|
||||
end
|
||||
|
||||
def build_bound_sql_literal(statement, values)
|
||||
bound_values = values.map do |value|
|
||||
if ActiveRecord::Relation === value
|
||||
Arel.sql(value.to_sql)
|
||||
elsif value.respond_to?(:map) && !value.acts_like?(:string)
|
||||
values = value.map { |v| v.respond_to?(:id_for_database) ? v.id_for_database : v }
|
||||
values.empty? ? nil : values
|
||||
else
|
||||
value = value.id_for_database if value.respond_to?(:id_for_database)
|
||||
value
|
||||
end
|
||||
end
|
||||
|
||||
begin
|
||||
Arel::Nodes::BoundSqlLiteral.new("(#{statement})", bound_values, nil)
|
||||
rescue Arel::BindError => error
|
||||
raise ActiveRecord::PreparedStatementInvalid, error.message
|
||||
end
|
||||
end
|
||||
|
||||
def lookup_table_klass_from_join_dependencies(table_name)
|
||||
each_join_dependencies do |join|
|
||||
return join.base_klass if table_name == join.table_name
|
||||
|
@ -89,19 +89,31 @@ def assert_no_queries_match(match, include_schema: false, &block)
|
||||
end
|
||||
|
||||
class SQLCounter # :nodoc:
|
||||
attr_reader :log, :log_all
|
||||
attr_reader :log_full, :log_all
|
||||
|
||||
def initialize
|
||||
@log = []
|
||||
@log_full = []
|
||||
@log_all = []
|
||||
end
|
||||
|
||||
def log
|
||||
@log_full.map(&:first)
|
||||
end
|
||||
|
||||
def call(*, payload)
|
||||
return if payload[:cached]
|
||||
|
||||
sql = payload[:sql]
|
||||
@log_all << sql
|
||||
@log << sql unless payload[:name] == "SCHEMA"
|
||||
|
||||
unless payload[:name] == "SCHEMA"
|
||||
bound_values = (payload[:binds] || []).map do |value|
|
||||
value = value.value_for_database if value.respond_to?(:value_for_database)
|
||||
value
|
||||
end
|
||||
|
||||
@log_full << [sql, bound_values]
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -6,13 +6,17 @@ class BoundSqlLiteral < NodeExpression
|
||||
attr_reader :sql_with_placeholders, :positional_binds, :named_binds
|
||||
|
||||
def initialize(sql_with_placeholders, positional_binds, named_binds)
|
||||
if !positional_binds.empty? && !named_binds.empty?
|
||||
raise BindError.new("cannot mix positional and named binds", sql_with_placeholders)
|
||||
elsif !positional_binds.empty?
|
||||
has_positional = !(positional_binds.nil? || positional_binds.empty?)
|
||||
has_named = !(named_binds.nil? || named_binds.empty?)
|
||||
|
||||
if has_positional
|
||||
if has_named
|
||||
raise BindError.new("cannot mix positional and named binds", sql_with_placeholders)
|
||||
end
|
||||
if positional_binds.size != (expected = sql_with_placeholders.count("?"))
|
||||
raise BindError.new("wrong number of bind variables (#{positional_binds.size} for #{expected})", sql_with_placeholders)
|
||||
end
|
||||
elsif !named_binds.empty?
|
||||
elsif has_named
|
||||
tokens_in_string = sql_with_placeholders.scan(/:(?<!::)([a-zA-Z]\w*)/).flatten.map(&:to_sym).uniq
|
||||
tokens_in_hash = named_binds.keys.map(&:to_sym).uniq
|
||||
|
||||
@ -26,7 +30,7 @@ def initialize(sql_with_placeholders, positional_binds, named_binds)
|
||||
end
|
||||
|
||||
@sql_with_placeholders = sql_with_placeholders
|
||||
if !positional_binds.empty?
|
||||
if has_positional
|
||||
@positional_binds = positional_binds
|
||||
@named_binds = nil
|
||||
else
|
||||
|
@ -16,7 +16,8 @@ def table(table)
|
||||
end
|
||||
|
||||
def set(values)
|
||||
if String === values
|
||||
case values
|
||||
when String, Nodes::BoundSqlLiteral
|
||||
@ast.values = [values]
|
||||
else
|
||||
@ast.values = values.map { |column, value|
|
||||
|
@ -49,7 +49,7 @@ def test_create_null_bytes
|
||||
end
|
||||
|
||||
def test_where_with_string_for_string_column_using_bind_parameters
|
||||
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog"
|
||||
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog", match: 1
|
||||
end
|
||||
|
||||
def test_where_with_integer_for_string_column_using_bind_parameters
|
||||
@ -73,14 +73,16 @@ def test_where_with_rational_for_string_column_using_bind_parameters
|
||||
end
|
||||
|
||||
private
|
||||
def assert_quoted_as(expected, value)
|
||||
def assert_quoted_as(expected, value, match: 0)
|
||||
relation = Post.where("title = ?", value)
|
||||
assert_equal(
|
||||
%{SELECT `posts`.* FROM `posts` WHERE (title = #{expected})},
|
||||
relation.to_sql,
|
||||
)
|
||||
assert_nothing_raised do # Make sure SQL is valid
|
||||
relation.to_a
|
||||
if match == 0
|
||||
assert_empty relation.to_a
|
||||
else
|
||||
assert_equal match, relation.count
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -10,44 +10,40 @@ class BindParameterTest < ActiveRecord::PostgreSQLTestCase
|
||||
fixtures :posts
|
||||
|
||||
def test_where_with_string_for_string_column_using_bind_parameters
|
||||
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog"
|
||||
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog", match: 1
|
||||
end
|
||||
|
||||
def test_where_with_integer_for_string_column_using_bind_parameters
|
||||
assert_quoted_as "0", 0, valid: false
|
||||
assert_quoted_as "0", 0
|
||||
end
|
||||
|
||||
def test_where_with_float_for_string_column_using_bind_parameters
|
||||
assert_quoted_as "0.0", 0.0, valid: false
|
||||
assert_quoted_as "0.0", 0.0
|
||||
end
|
||||
|
||||
def test_where_with_boolean_for_string_column_using_bind_parameters
|
||||
assert_quoted_as "FALSE", false, valid: false
|
||||
assert_quoted_as "FALSE", false
|
||||
end
|
||||
|
||||
def test_where_with_decimal_for_string_column_using_bind_parameters
|
||||
assert_quoted_as "0.0", BigDecimal(0), valid: false
|
||||
assert_quoted_as "0.0", BigDecimal(0)
|
||||
end
|
||||
|
||||
def test_where_with_rational_for_string_column_using_bind_parameters
|
||||
assert_quoted_as "0/1", Rational(0), valid: false
|
||||
assert_quoted_as "0/1", Rational(0)
|
||||
end
|
||||
|
||||
private
|
||||
def assert_quoted_as(expected, value, valid: true)
|
||||
def assert_quoted_as(expected, value, match: 0)
|
||||
relation = Post.where("title = ?", value)
|
||||
assert_equal(
|
||||
%{SELECT "posts".* FROM "posts" WHERE (title = #{expected})},
|
||||
relation.to_sql,
|
||||
)
|
||||
if valid
|
||||
assert_nothing_raised do # Make sure SQL is valid
|
||||
relation.to_a
|
||||
end
|
||||
if match == 0
|
||||
assert_empty relation.to_a
|
||||
else
|
||||
assert_raises ActiveRecord::StatementInvalid do
|
||||
relation.to_a
|
||||
end
|
||||
assert_equal match, relation.count
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -10,7 +10,7 @@ class BindParameterTest < ActiveRecord::SQLite3TestCase
|
||||
fixtures :posts
|
||||
|
||||
def test_where_with_string_for_string_column_using_bind_parameters
|
||||
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog"
|
||||
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog", match: 1
|
||||
end
|
||||
|
||||
def test_where_with_integer_for_string_column_using_bind_parameters
|
||||
@ -34,14 +34,16 @@ def test_where_with_rational_for_string_column_using_bind_parameters
|
||||
end
|
||||
|
||||
private
|
||||
def assert_quoted_as(expected, value)
|
||||
def assert_quoted_as(expected, value, match: 0)
|
||||
relation = Post.where("title = ?", value)
|
||||
assert_equal(
|
||||
%{SELECT "posts".* FROM "posts" WHERE (title = #{expected})},
|
||||
relation.to_sql,
|
||||
)
|
||||
assert_nothing_raised do # Make sure SQL is valid
|
||||
relation.to_a
|
||||
if match == 0
|
||||
assert_empty relation.to_a
|
||||
else
|
||||
assert_equal match, relation.count
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -252,9 +252,9 @@ def test_default_scope_on_relations_is_not_cached
|
||||
assert_equal 0, counter
|
||||
comment = comments.first
|
||||
assert_equal 0, counter
|
||||
sql = capture_sql { comment.post }
|
||||
queries = capture_sql_and_binds { comment.post }
|
||||
comment.reload
|
||||
assert_not_equal sql, capture_sql { comment.post }
|
||||
assert_not_equal queries, capture_sql_and_binds { comment.post }
|
||||
end
|
||||
|
||||
def test_proxy_assignment
|
||||
|
@ -171,9 +171,9 @@ def test_default_scope_on_relations_is_not_cached
|
||||
assert_equal 0, counter
|
||||
post = posts.first
|
||||
assert_equal 0, counter
|
||||
sql = capture_sql { post.comments.to_a }
|
||||
queries = capture_sql_and_binds { post.comments.to_a }
|
||||
post.comments.reset
|
||||
assert_not_equal sql, capture_sql { post.comments.to_a }
|
||||
assert_not_equal queries, capture_sql_and_binds { post.comments.to_a }
|
||||
end
|
||||
|
||||
def test_has_many_build_with_options
|
||||
|
@ -102,7 +102,7 @@ def test_statement_cache_with_sql_string_literal
|
||||
|
||||
topics = Topic.where("topics.id = ?", 1)
|
||||
assert_equal [1], topics.map(&:id)
|
||||
assert_not_includes statement_cache, to_sql_key(topics.arel)
|
||||
assert_includes statement_cache, to_sql_key(topics.arel)
|
||||
end
|
||||
|
||||
def test_too_many_binds
|
||||
|
@ -200,7 +200,7 @@ def test_exists
|
||||
assert_equal false, Topic.exists?(9999999999999999999999999999999)
|
||||
assert_equal false, Topic.exists?(Topic.new.id)
|
||||
|
||||
assert_raise(NoMethodError) { Topic.exists?([1, 2]) }
|
||||
assert_raise(ArgumentError) { Topic.exists?([1, 2]) }
|
||||
end
|
||||
|
||||
def test_exists_with_scope
|
||||
|
@ -146,14 +146,20 @@ def test_merge_doesnt_duplicate_same_clauses
|
||||
|
||||
only_david = Author.where("#{author_id} IN (?)", david)
|
||||
|
||||
if current_adapter?(:Mysql2Adapter, :TrilogyAdapter)
|
||||
assert_queries_match(/WHERE \(#{Regexp.escape(author_id)} IN \('1'\)\)\z/) do
|
||||
assert_equal [david], only_david.merge(only_david)
|
||||
matcher = if Author.connection.prepared_statements
|
||||
if current_adapter?(:PostgreSQLAdapter)
|
||||
/WHERE \(#{Regexp.escape(author_id)} IN \(\$1\)\)\z/
|
||||
else
|
||||
/WHERE \(#{Regexp.escape(author_id)} IN \(\?\)\)\z/
|
||||
end
|
||||
elsif current_adapter?(:Mysql2Adapter, :TrilogyAdapter)
|
||||
/WHERE \(#{Regexp.escape(author_id)} IN \('1'\)\)\z/
|
||||
else
|
||||
assert_queries_match(/WHERE \(#{Regexp.escape(author_id)} IN \(1\)\)\z/) do
|
||||
assert_equal [david], only_david.merge(only_david)
|
||||
end
|
||||
/WHERE \(#{Regexp.escape(author_id)} IN \(1\)\)\z/
|
||||
end
|
||||
|
||||
assert_queries_match(matcher) do
|
||||
assert_equal [david], only_david.merge(only_david)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -204,7 +204,7 @@ def self.sanitize_sql(args)
|
||||
|
||||
relation = Relation.new(klass)
|
||||
relation.merge!(where: ["foo = ?", "bar"])
|
||||
assert_equal Relation::WhereClause.new(["foo = bar"]), relation.where_clause
|
||||
assert_equal Relation::WhereClause.new([Arel.sql("(foo = ?)", "bar")]), relation.where_clause
|
||||
end
|
||||
|
||||
def test_merging_readonly_false
|
||||
|
@ -476,9 +476,9 @@ def test_finding_with_complex_order
|
||||
def test_finding_with_sanitized_order
|
||||
query = Tag.order([Arel.sql("field(id, ?)"), [1, 3, 2]]).to_sql
|
||||
if current_adapter?(:Mysql2Adapter, :TrilogyAdapter)
|
||||
assert_match(/field\(id, '1','3','2'\)/, query)
|
||||
assert_match(/field\(id, '1',\s*'3',\s*'2'\)/, query)
|
||||
else
|
||||
assert_match(/field\(id, 1,3,2\)/, query)
|
||||
assert_match(/field\(id, 1,\s*3,\s*2\)/, query)
|
||||
end
|
||||
|
||||
query = Tag.order([Arel.sql("field(id, ?)"), []]).to_sql
|
||||
|
@ -91,11 +91,21 @@ def self.search_as_method(term)
|
||||
}
|
||||
end
|
||||
|
||||
assert_queries_match(/LIKE '20!% !_reduction!_!!'/) do
|
||||
query = if searchable_post.connection.prepared_statements
|
||||
if current_adapter?(:PostgreSQLAdapter)
|
||||
/title LIKE \$1/
|
||||
else
|
||||
/title LIKE \?/
|
||||
end
|
||||
else
|
||||
/LIKE '20!% !_reduction!_!!'/
|
||||
end
|
||||
|
||||
assert_queries_match(query) do
|
||||
searchable_post.search_as_method("20% _reduction_!").to_a
|
||||
end
|
||||
|
||||
assert_queries_match(/LIKE '20!% !_reduction!_!!'/) do
|
||||
assert_queries_match(query) do
|
||||
searchable_post.search_as_scope("20% _reduction_!").to_a
|
||||
end
|
||||
end
|
||||
|
@ -99,6 +99,14 @@ def capture_sql(include_schema: false)
|
||||
end
|
||||
end
|
||||
|
||||
def capture_sql_and_binds
|
||||
counter = SQLCounter.new
|
||||
ActiveSupport::Notifications.subscribed(counter, "sql.active_record") do
|
||||
yield
|
||||
counter.log_full
|
||||
end
|
||||
end
|
||||
|
||||
# Redefine existing assertion method to explicitly not materialize transactions.
|
||||
def assert_queries_match(match, count: nil, include_schema: false, &block)
|
||||
counter = SQLCounter.new
|
||||
|
Loading…
Reference in New Issue
Block a user