Merge pull request #50396 from Shopify/stricter-relation-delegation

Make the Relation -> Model delegation stricter
This commit is contained in:
Jean Boussier 2024-05-28 08:47:41 +02:00 committed by GitHub
commit 407031f8b8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 56 additions and 50 deletions

@ -17,12 +17,12 @@ def initialize(scope, association_key_name)
def eql?(other)
association_key_name == other.association_key_name &&
scope.table_name == other.scope.table_name &&
scope.connection_specification_name == other.scope.connection_specification_name &&
scope.model.connection_specification_name == other.scope.model.connection_specification_name &&
scope.values_for_queries == other.scope.values_for_queries
end
def hash
[association_key_name, scope.table_name, scope.connection_specification_name, scope.values_for_queries].hash
[association_key_name, scope.model.table_name, scope.model.connection_specification_name, scope.values_for_queries].hash
end
def records_for(loaders)

@ -41,6 +41,8 @@ def self.install_support
module EncryptedQuery # :nodoc:
class << self
def process_arguments(owner, args, check_for_additional_values)
owner = owner.model if owner.is_a?(Relation)
return args if owner.deterministic_encrypted_attributes&.empty?
if args.is_a?(Array) && (options = args.first).is_a?(Hash)

@ -437,7 +437,7 @@ def compute_cache_key(timestamp_column = :updated_at) # :nodoc:
query_signature = ActiveSupport::Digest.hexdigest(to_sql)
key = "#{klass.model_name.cache_key}/query-#{query_signature}"
if collection_cache_versioning
if model.collection_cache_versioning
key
else
"#{key}-#{compute_cache_version(timestamp_column)}"
@ -456,7 +456,7 @@ def compute_cache_key(timestamp_column = :updated_at) # :nodoc:
#
# SELECT COUNT(*), MAX("products"."updated_at") FROM "products" WHERE (name like '%Cosmic Encounter%')
def cache_version(timestamp_column = :updated_at)
if collection_cache_versioning
if model.collection_cache_versioning
@cache_versions ||= {}
@cache_versions[timestamp_column] ||= compute_cache_version(timestamp_column)
end
@ -475,7 +475,7 @@ def compute_cache_version(timestamp_column) # :nodoc:
with_connection do |c|
column = c.visitor.compile(table[timestamp_column])
select_values = "COUNT(*) AS #{adapter_class.quote_column_name("size")}, MAX(%s) AS timestamp"
select_values = "COUNT(*) AS #{klass.adapter_class.quote_column_name("size")}, MAX(%s) AS timestamp"
if collection.has_limit_or_offset?
query = collection.select("#{column} AS collection_cache_key_timestamp")
@ -501,7 +501,7 @@ def compute_cache_version(timestamp_column) # :nodoc:
end
if timestamp
"#{size}-#{timestamp.utc.to_fs(cache_timestamp_format)}"
"#{size}-#{timestamp.utc.to_fs(model.cache_timestamp_format)}"
else
"#{size}"
end
@ -1291,7 +1291,7 @@ def has_limit_or_offset? # :nodoc:
end
def alias_tracker(joins = [], aliases = nil) # :nodoc:
ActiveRecord::Associations::AliasTracker.create(connection_pool, table.name, joins, aliases)
ActiveRecord::Associations::AliasTracker.create(klass.connection_pool, table.name, joins, aliases)
end
class StrictLoadingScope # :nodoc:
@ -1451,7 +1451,7 @@ def instantiate_records(rows, &block)
def skip_query_cache_if_necessary(&block)
if skip_query_cache_value
uncached(&block)
model.uncached(&block)
else
yield
end

@ -327,8 +327,8 @@ def act_on_ignored_order(error_on_ignore)
if raise_error
raise ArgumentError.new(ORDER_IGNORE_MESSAGE)
elsif logger
logger.warn(ORDER_IGNORE_MESSAGE)
elsif model.logger
model.logger.warn(ORDER_IGNORE_MESSAGE)
end
end

@ -522,13 +522,13 @@ def execute_grouped_calculation(operation, column_name, distinct) # :nodoc:
column = aggregate_column(column_name)
column_alias = column_alias_tracker.alias_for("#{operation} #{column_name.to_s.downcase}")
select_value = operation_over_aggregate_column(column, operation, distinct)
select_value.as(adapter_class.quote_column_name(column_alias))
select_value.as(klass.adapter_class.quote_column_name(column_alias))
select_values = [select_value]
select_values += self.select_values unless having_clause.empty?
select_values.concat group_columns.map { |aliaz, field|
aliaz = adapter_class.quote_column_name(aliaz)
aliaz = klass.adapter_class.quote_column_name(aliaz)
if field.respond_to?(:as)
field.as(aliaz)
else
@ -633,6 +633,7 @@ def select_for_count
if select_values.present?
return select_values.first if select_values.one?
adapter_class = klass.adapter_class
select_values.map do |field|
column = arel_column(field.to_s) do |attr_name|
Arel.sql(attr_name)

@ -22,6 +22,9 @@ def uncacheable_methods
end
module DelegateCache # :nodoc:
@delegate_base_methods = true
singleton_class.attr_accessor :delegate_base_methods
def relation_delegate_class(klass)
@relation_delegate_cache[klass]
end
@ -100,7 +103,7 @@ def #{method}(...)
:to_sentence, :to_fs, :to_formatted_s, :as_json,
:shuffle, :split, :slice, :index, :rindex, to: :records
delegate :primary_key, :lease_connection, :connection, :with_connection, :transaction, to: :klass
delegate :primary_key, :with_connection, :connection, :table_name, :transaction, :sanitize_sql_like, :unscoped, to: :klass
module ClassSpecificRelation # :nodoc:
extend ActiveSupport::Concern
@ -114,9 +117,17 @@ def name
private
def method_missing(method, ...)
if @klass.respond_to?(method)
unless Delegation.uncacheable_methods.include?(method)
if !DelegateCache.delegate_base_methods && Base.respond_to?(method)
# A common mistake in Active Record's own code is to call `ActiveRecord::Base`
# class methods on Association. It works because it's automatically delegated, but
# can introduce subtle bugs because it sets the global scope.
# We can't deprecate this behavior because gems might depend on it, however we
# can ban it from Active Record's own test suite to avoid regressions.
raise NotImplementedError, "Active Record code shouldn't rely on association delegation into ActiveRecord::Base methods"
elsif !Delegation.uncacheable_methods.include?(method)
@klass.generate_relation_method(method)
end
scoping { @klass.public_send(method, ...) }
else
super

@ -145,10 +145,10 @@ def sole
if found.nil?
raise_record_not_found_exception!
elsif undesired.present?
raise ActiveRecord::SoleRecordExceeded.new(self)
else
elsif undesired.nil?
found
else
raise ActiveRecord::SoleRecordExceeded.new(model)
end
end
@ -376,7 +376,7 @@ def exists?(conditions = :none)
skip_query_cache_if_necessary do
with_connection do |c|
c.select_rows(relation.arel, "#{name} Exists?").size == 1
c.select_rows(relation.arel, "#{klass.name} Exists?").size == 1
end
end
end
@ -638,7 +638,7 @@ def find_last(limit)
end
def ordered_relation
if order_values.empty? && (implicit_order_column || !query_constraints_list.nil? || primary_key)
if order_values.empty? && (model.implicit_order_column || !model.query_constraints_list.nil? || primary_key)
order(_order_columns.map { |column| table[column].asc })
else
self
@ -648,11 +648,11 @@ def ordered_relation
def _order_columns
oc = []
oc << implicit_order_column if implicit_order_column
oc << query_constraints_list if query_constraints_list
oc << model.implicit_order_column if model.implicit_order_column
oc << model.query_constraints_list if model.query_constraints_list
if primary_key && query_constraints_list.nil?
oc << primary_key
if model.primary_key && model.query_constraints_list.nil?
oc << model.primary_key
end
oc.flatten.uniq.compact

@ -136,9 +136,10 @@ def missing(*associations)
private
def scope_association_reflection(association)
reflection = @scope.klass._reflect_on_association(association)
model = @scope.model
reflection = model._reflect_on_association(association)
unless reflection
raise ArgumentError.new("An association named `:#{association}` does not exist on the model `#{@scope.name}`.")
raise ArgumentError.new("An association named `:#{association}` does not exist on the model `#{model.name}`.")
end
reflection
end
@ -254,6 +255,10 @@ def includes!(*args) # :nodoc:
self
end
def all # :nodoc:
spawn
end
# Specify associations +args+ to be eager loaded using a <tt>LEFT OUTER JOIN</tt>.
# Performs a single query joining all specified associations. For example:
#
@ -703,7 +708,7 @@ def in_order_of(column, values)
references = column_references([column])
self.references_values |= references unless references.empty?
values = values.map { |value| type_caster.type_cast_for_database(column, value) }
values = values.map { |value| model.type_caster.type_cast_for_database(column, value) }
arel_column = column.is_a?(Arel::Nodes::SqlLiteral) ? column : order_column(column.to_s)
where_clause =
@ -1914,7 +1919,7 @@ def arel_columns(columns)
case field
when Symbol
arel_column(field.to_s) do |attr_name|
adapter_class.quote_table_name(attr_name)
klass.adapter_class.quote_table_name(attr_name)
end
when String
arel_column(field, &:itself)
@ -1946,7 +1951,7 @@ def arel_column(field)
def table_name_matches?(from)
table_name = Regexp.escape(table.name)
quoted_table_name = Regexp.escape(adapter_class.quote_table_name(table.name))
quoted_table_name = Regexp.escape(klass.adapter_class.quote_table_name(table.name))
/(?:\A|(?<!FROM)\s)(?:\b#{table_name}\b|#{quoted_table_name})(?!\.)/i.match?(from.to_s)
end
@ -2081,7 +2086,7 @@ def order_column(field)
if attr_name == "count" && !group_values.empty?
table[attr_name]
else
Arel.sql(adapter_class.quote_table_name(attr_name), retryable: true)
Arel.sql(klass.adapter_class.quote_table_name(attr_name), retryable: true)
end
end
end

@ -20,9 +20,9 @@ def exec_queries
QueryRegistry.reset
super.tap do |records|
if logger && ActiveRecord.warn_on_records_fetched_greater_than
if model.logger && ActiveRecord.warn_on_records_fetched_greater_than
if records.length > ActiveRecord.warn_on_records_fetched_greater_than
logger.warn "Query fetched #{records.size} #{@klass} records: #{QueryRegistry.queries.join(";")}"
model.logger.warn "Query fetched #{records.size} #{@klass} records: #{QueryRegistry.queries.join(";")}"
end
end
end

@ -823,21 +823,6 @@ def test_association_proxy_transaction_method_starts_transaction_in_association_
end
end
def test_caching_of_columns
david = Developer.find(1)
# clear cache possibly created by other tests
david.projects.reset_column_information
assert_queries_count(include_schema: true) { david.projects.columns }
assert_no_queries { david.projects.columns }
## and again to verify that reset_column_information clears the cache correctly
david.projects.reset_column_information
assert_queries_count(include_schema: true) { david.projects.columns }
assert_no_queries { david.projects.columns }
end
def test_attributes_are_being_set_when_initialized_from_habtm_association_with_where_clause
new_developer = projects(:action_controller).developers.where(name: "Marcelo").build
assert_equal "Marcelo", new_developer.name

@ -26,6 +26,8 @@
# to ensure it's not used internally.
ActiveRecord.permanent_connection_checkout = :disallowed
ActiveRecord::Delegation::DelegateCache.delegate_base_methods = false
# Disable available locale checks to avoid warnings running the test suite.
I18n.enforce_available_locales = false

@ -62,7 +62,7 @@ class QueryingMethodsDelegationTest < ActiveRecord::TestCase
ActiveRecord::SpawnMethods.public_instance_methods(false) - [:spawn, :merge!] +
ActiveRecord::QueryMethods.public_instance_methods(false).reject { |method|
method.end_with?("=", "!", "?", "value", "values", "clause")
} - [:reverse_order, :arel, :extensions, :construct_join_dependency] + [
} - [:all, :reverse_order, :arel, :extensions, :construct_join_dependency] + [
:any?, :many?, :none?, :one?,
:first_or_create, :first_or_create!, :first_or_initialize,
:find_or_create_by, :find_or_create_by!, :find_or_initialize_by,

@ -209,10 +209,10 @@ def ratings
has_many :posts_with_default_include, class_name: "PostWithDefaultInclude"
has_many :comments_on_posts_with_default_include, through: :posts_with_default_include, source: :comments
has_many :posts_with_signature, ->(record) { where(arel_table[:title].matches("%by #{record.name.downcase}%")) }, class_name: "Post"
has_many :posts_mentioning_author, ->(record = nil) { where(arel_table[:body].matches("%#{record&.name&.downcase}%")) }, class_name: "Post"
has_many :posts_with_signature, ->(record) { where(model.arel_table[:title].matches("%by #{record.name.downcase}%")) }, class_name: "Post"
has_many :posts_mentioning_author, ->(record = nil) { where(model.arel_table[:body].matches("%#{record&.name&.downcase}%")) }, class_name: "Post"
has_many :comments_on_posts_mentioning_author, through: :posts_mentioning_author, source: :comments
has_many :comments_mentioning_author, ->(record) { where(arel_table[:body].matches("%#{record.name.downcase}%")) }, through: :posts, source: :comments
has_many :comments_mentioning_author, ->(record) { where(model.arel_table[:body].matches("%#{record.name.downcase}%")) }, through: :posts, source: :comments
has_one :recent_post, -> { order(id: :desc) }, class_name: "Post"
has_one :recent_response, through: :recent_post, source: :comments