Merge pull request #51139 from Shopify/relation-bound-sql-literal

Relation#where build BoundSqlLiteral rather than eagerly interpolate
This commit is contained in:
Jean Boussier 2024-02-21 16:49:45 +01:00 committed by GitHub
commit 684131a4f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
19 changed files with 146 additions and 54 deletions

@ -21,9 +21,6 @@ def cast_bound_value(value)
"1"
when false
"0"
when ActiveSupport::Duration
warn_quote_duration_deprecated
value.to_s
else
value
end

@ -141,6 +141,8 @@ def type_cast(value) # :nodoc:
encode_array(value)
when Range
encode_range(value)
when Rational
value.to_f
else
super
end

@ -63,7 +63,7 @@ def quote_default_expression(value, column) # :nodoc:
def type_cast(value) # :nodoc:
case value
when BigDecimal
when BigDecimal, Rational
value.to_f
when String
if value.encoding == Encoding::ASCII_8BIT

@ -1512,9 +1512,21 @@ def build_subquery(subquery_alias, select_value) # :nodoc:
def build_where_clause(opts, rest = []) # :nodoc:
opts = sanitize_forbidden_attributes(opts)
if opts.is_a?(Array)
opts, *rest = opts
end
case opts
when String, Array
parts = [klass.sanitize_sql(rest.empty? ? opts : [opts, *rest])]
when String
if rest.empty?
parts = [Arel.sql(opts)]
elsif rest.first.is_a?(Hash) && /:\w+/.match?(opts)
parts = [build_named_bound_sql_literal(opts, rest.first)]
elsif opts.include?("?")
parts = [build_bound_sql_literal(opts, rest)]
else
parts = [klass.sanitize_sql(rest.empty? ? opts : [opts, *rest])]
end
when Hash
opts = opts.transform_keys do |key|
if key.is_a?(Array)
@ -1550,6 +1562,46 @@ def async
spawn.async!
end
def build_named_bound_sql_literal(statement, values)
bound_values = values.transform_values do |value|
if ActiveRecord::Relation === value
Arel.sql(value.to_sql)
elsif value.respond_to?(:map) && !value.acts_like?(:string)
values = value.map { |v| v.respond_to?(:id_for_database) ? v.id_for_database : v }
values.empty? ? nil : values
else
value = value.id_for_database if value.respond_to?(:id_for_database)
value
end
end
begin
Arel::Nodes::BoundSqlLiteral.new("(#{statement})", nil, bound_values)
rescue Arel::BindError => error
raise ActiveRecord::PreparedStatementInvalid, error.message
end
end
def build_bound_sql_literal(statement, values)
bound_values = values.map do |value|
if ActiveRecord::Relation === value
Arel.sql(value.to_sql)
elsif value.respond_to?(:map) && !value.acts_like?(:string)
values = value.map { |v| v.respond_to?(:id_for_database) ? v.id_for_database : v }
values.empty? ? nil : values
else
value = value.id_for_database if value.respond_to?(:id_for_database)
value
end
end
begin
Arel::Nodes::BoundSqlLiteral.new("(#{statement})", bound_values, nil)
rescue Arel::BindError => error
raise ActiveRecord::PreparedStatementInvalid, error.message
end
end
def lookup_table_klass_from_join_dependencies(table_name)
each_join_dependencies do |join|
return join.base_klass if table_name == join.table_name

@ -89,19 +89,31 @@ def assert_no_queries_match(match, include_schema: false, &block)
end
class SQLCounter # :nodoc:
attr_reader :log, :log_all
attr_reader :log_full, :log_all
def initialize
@log = []
@log_full = []
@log_all = []
end
def log
@log_full.map(&:first)
end
def call(*, payload)
return if payload[:cached]
sql = payload[:sql]
@log_all << sql
@log << sql unless payload[:name] == "SCHEMA"
unless payload[:name] == "SCHEMA"
bound_values = (payload[:binds] || []).map do |value|
value = value.value_for_database if value.respond_to?(:value_for_database)
value
end
@log_full << [sql, bound_values]
end
end
end
end

@ -6,13 +6,17 @@ class BoundSqlLiteral < NodeExpression
attr_reader :sql_with_placeholders, :positional_binds, :named_binds
def initialize(sql_with_placeholders, positional_binds, named_binds)
if !positional_binds.empty? && !named_binds.empty?
raise BindError.new("cannot mix positional and named binds", sql_with_placeholders)
elsif !positional_binds.empty?
has_positional = !(positional_binds.nil? || positional_binds.empty?)
has_named = !(named_binds.nil? || named_binds.empty?)
if has_positional
if has_named
raise BindError.new("cannot mix positional and named binds", sql_with_placeholders)
end
if positional_binds.size != (expected = sql_with_placeholders.count("?"))
raise BindError.new("wrong number of bind variables (#{positional_binds.size} for #{expected})", sql_with_placeholders)
end
elsif !named_binds.empty?
elsif has_named
tokens_in_string = sql_with_placeholders.scan(/:(?<!::)([a-zA-Z]\w*)/).flatten.map(&:to_sym).uniq
tokens_in_hash = named_binds.keys.map(&:to_sym).uniq
@ -26,7 +30,7 @@ def initialize(sql_with_placeholders, positional_binds, named_binds)
end
@sql_with_placeholders = sql_with_placeholders
if !positional_binds.empty?
if has_positional
@positional_binds = positional_binds
@named_binds = nil
else

@ -16,7 +16,8 @@ def table(table)
end
def set(values)
if String === values
case values
when String, Nodes::BoundSqlLiteral
@ast.values = [values]
else
@ast.values = values.map { |column, value|

@ -49,7 +49,7 @@ def test_create_null_bytes
end
def test_where_with_string_for_string_column_using_bind_parameters
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog"
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog", match: 1
end
def test_where_with_integer_for_string_column_using_bind_parameters
@ -73,14 +73,16 @@ def test_where_with_rational_for_string_column_using_bind_parameters
end
private
def assert_quoted_as(expected, value)
def assert_quoted_as(expected, value, match: 0)
relation = Post.where("title = ?", value)
assert_equal(
%{SELECT `posts`.* FROM `posts` WHERE (title = #{expected})},
relation.to_sql,
)
assert_nothing_raised do # Make sure SQL is valid
relation.to_a
if match == 0
assert_empty relation.to_a
else
assert_equal match, relation.count
end
end
end

@ -10,44 +10,40 @@ class BindParameterTest < ActiveRecord::PostgreSQLTestCase
fixtures :posts
def test_where_with_string_for_string_column_using_bind_parameters
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog"
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog", match: 1
end
def test_where_with_integer_for_string_column_using_bind_parameters
assert_quoted_as "0", 0, valid: false
assert_quoted_as "0", 0
end
def test_where_with_float_for_string_column_using_bind_parameters
assert_quoted_as "0.0", 0.0, valid: false
assert_quoted_as "0.0", 0.0
end
def test_where_with_boolean_for_string_column_using_bind_parameters
assert_quoted_as "FALSE", false, valid: false
assert_quoted_as "FALSE", false
end
def test_where_with_decimal_for_string_column_using_bind_parameters
assert_quoted_as "0.0", BigDecimal(0), valid: false
assert_quoted_as "0.0", BigDecimal(0)
end
def test_where_with_rational_for_string_column_using_bind_parameters
assert_quoted_as "0/1", Rational(0), valid: false
assert_quoted_as "0/1", Rational(0)
end
private
def assert_quoted_as(expected, value, valid: true)
def assert_quoted_as(expected, value, match: 0)
relation = Post.where("title = ?", value)
assert_equal(
%{SELECT "posts".* FROM "posts" WHERE (title = #{expected})},
relation.to_sql,
)
if valid
assert_nothing_raised do # Make sure SQL is valid
relation.to_a
end
if match == 0
assert_empty relation.to_a
else
assert_raises ActiveRecord::StatementInvalid do
relation.to_a
end
assert_equal match, relation.count
end
end
end

@ -10,7 +10,7 @@ class BindParameterTest < ActiveRecord::SQLite3TestCase
fixtures :posts
def test_where_with_string_for_string_column_using_bind_parameters
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog"
assert_quoted_as "'Welcome to the weblog'", "Welcome to the weblog", match: 1
end
def test_where_with_integer_for_string_column_using_bind_parameters
@ -34,14 +34,16 @@ def test_where_with_rational_for_string_column_using_bind_parameters
end
private
def assert_quoted_as(expected, value)
def assert_quoted_as(expected, value, match: 0)
relation = Post.where("title = ?", value)
assert_equal(
%{SELECT "posts".* FROM "posts" WHERE (title = #{expected})},
relation.to_sql,
)
assert_nothing_raised do # Make sure SQL is valid
relation.to_a
if match == 0
assert_empty relation.to_a
else
assert_equal match, relation.count
end
end
end

@ -252,9 +252,9 @@ def test_default_scope_on_relations_is_not_cached
assert_equal 0, counter
comment = comments.first
assert_equal 0, counter
sql = capture_sql { comment.post }
queries = capture_sql_and_binds { comment.post }
comment.reload
assert_not_equal sql, capture_sql { comment.post }
assert_not_equal queries, capture_sql_and_binds { comment.post }
end
def test_proxy_assignment

@ -171,9 +171,9 @@ def test_default_scope_on_relations_is_not_cached
assert_equal 0, counter
post = posts.first
assert_equal 0, counter
sql = capture_sql { post.comments.to_a }
queries = capture_sql_and_binds { post.comments.to_a }
post.comments.reset
assert_not_equal sql, capture_sql { post.comments.to_a }
assert_not_equal queries, capture_sql_and_binds { post.comments.to_a }
end
def test_has_many_build_with_options

@ -102,7 +102,7 @@ def test_statement_cache_with_sql_string_literal
topics = Topic.where("topics.id = ?", 1)
assert_equal [1], topics.map(&:id)
assert_not_includes statement_cache, to_sql_key(topics.arel)
assert_includes statement_cache, to_sql_key(topics.arel)
end
def test_too_many_binds

@ -200,7 +200,7 @@ def test_exists
assert_equal false, Topic.exists?(9999999999999999999999999999999)
assert_equal false, Topic.exists?(Topic.new.id)
assert_raise(NoMethodError) { Topic.exists?([1, 2]) }
assert_raise(ArgumentError) { Topic.exists?([1, 2]) }
end
def test_exists_with_scope

@ -146,14 +146,20 @@ def test_merge_doesnt_duplicate_same_clauses
only_david = Author.where("#{author_id} IN (?)", david)
if current_adapter?(:Mysql2Adapter, :TrilogyAdapter)
assert_queries_match(/WHERE \(#{Regexp.escape(author_id)} IN \('1'\)\)\z/) do
assert_equal [david], only_david.merge(only_david)
matcher = if Author.connection.prepared_statements
if current_adapter?(:PostgreSQLAdapter)
/WHERE \(#{Regexp.escape(author_id)} IN \(\$1\)\)\z/
else
/WHERE \(#{Regexp.escape(author_id)} IN \(\?\)\)\z/
end
elsif current_adapter?(:Mysql2Adapter, :TrilogyAdapter)
/WHERE \(#{Regexp.escape(author_id)} IN \('1'\)\)\z/
else
assert_queries_match(/WHERE \(#{Regexp.escape(author_id)} IN \(1\)\)\z/) do
assert_equal [david], only_david.merge(only_david)
end
/WHERE \(#{Regexp.escape(author_id)} IN \(1\)\)\z/
end
assert_queries_match(matcher) do
assert_equal [david], only_david.merge(only_david)
end
end

@ -204,7 +204,7 @@ def self.sanitize_sql(args)
relation = Relation.new(klass)
relation.merge!(where: ["foo = ?", "bar"])
assert_equal Relation::WhereClause.new(["foo = bar"]), relation.where_clause
assert_equal Relation::WhereClause.new([Arel.sql("(foo = ?)", "bar")]), relation.where_clause
end
def test_merging_readonly_false

@ -476,9 +476,9 @@ def test_finding_with_complex_order
def test_finding_with_sanitized_order
query = Tag.order([Arel.sql("field(id, ?)"), [1, 3, 2]]).to_sql
if current_adapter?(:Mysql2Adapter, :TrilogyAdapter)
assert_match(/field\(id, '1','3','2'\)/, query)
assert_match(/field\(id, '1',\s*'3',\s*'2'\)/, query)
else
assert_match(/field\(id, 1,3,2\)/, query)
assert_match(/field\(id, 1,\s*3,\s*2\)/, query)
end
query = Tag.order([Arel.sql("field(id, ?)"), []]).to_sql

@ -91,11 +91,21 @@ def self.search_as_method(term)
}
end
assert_queries_match(/LIKE '20!% !_reduction!_!!'/) do
query = if searchable_post.connection.prepared_statements
if current_adapter?(:PostgreSQLAdapter)
/title LIKE \$1/
else
/title LIKE \?/
end
else
/LIKE '20!% !_reduction!_!!'/
end
assert_queries_match(query) do
searchable_post.search_as_method("20% _reduction_!").to_a
end
assert_queries_match(/LIKE '20!% !_reduction!_!!'/) do
assert_queries_match(query) do
searchable_post.search_as_scope("20% _reduction_!").to_a
end
end

@ -99,6 +99,14 @@ def capture_sql(include_schema: false)
end
end
def capture_sql_and_binds
counter = SQLCounter.new
ActiveSupport::Notifications.subscribed(counter, "sql.active_record") do
yield
counter.log_full
end
end
# Redefine existing assertion method to explicitly not materialize transactions.
def assert_queries_match(match, count: nil, include_schema: false, &block)
counter = SQLCounter.new