diff --git a/activerecord/lib/active_record/relation.rb b/activerecord/lib/active_record/relation.rb index 34ea09331c..f4900fab59 100644 --- a/activerecord/lib/active_record/relation.rb +++ b/activerecord/lib/active_record/relation.rb @@ -422,18 +422,20 @@ def cache_key_with_version # Please check unscoped if you want to remove all previous scopes (including # the default_scope) during the execution of a block. def scoping(all_queries: nil) - if global_scope? && all_queries == false + registry = klass.scope_registry + if global_scope?(registry) && all_queries == false raise ArgumentError, "Scoping is set to apply to all queries and cannot be unset in a nested block." - elsif already_in_scope? + elsif already_in_scope?(registry) yield else - _scoping(self, all_queries) { yield } + _scoping(self, registry, all_queries) { yield } end end def _exec_scope(*args, &block) # :nodoc: @delegate_to_klass = true - _scoping(nil) { instance_exec(*args, &block) || self } + registry = klass.scope_registry + _scoping(nil, registry) { instance_exec(*args, &block) || self } ensure @delegate_to_klass = false end @@ -830,12 +832,12 @@ def null_relation? # :nodoc: end private - def already_in_scope? - @delegate_to_klass && klass.current_scope(true) + def already_in_scope?(registry) + @delegate_to_klass && registry.current_scope(klass, true) end - def global_scope? - klass.global_current_scope(true) + def global_scope?(registry) + registry.global_current_scope(klass, true) end def current_scope_restoring_block(&block) @@ -858,16 +860,19 @@ def _create!(attributes, &block) klass.create!(attributes, &block) end - def _scoping(scope, all_queries = false) - previous, klass.current_scope = klass.current_scope(true), scope + def _scoping(scope, registry, all_queries = false) + previous = registry.current_scope(klass, true) + registry.set_current_scope(klass, scope) + if all_queries - previous_global, klass.global_current_scope = klass.global_current_scope(true), scope + previous_global = registry.global_current_scope(klass, true) + registry.set_global_current_scope(klass, scope) end yield ensure - klass.current_scope = previous + registry.set_current_scope(klass, previous) if all_queries - klass.global_current_scope = previous_global + registry.set_global_current_scope(klass, previous_global) end end diff --git a/activerecord/lib/active_record/relation/spawn_methods.rb b/activerecord/lib/active_record/relation/spawn_methods.rb index 8636b34434..5f3be4733e 100644 --- a/activerecord/lib/active_record/relation/spawn_methods.rb +++ b/activerecord/lib/active_record/relation/spawn_methods.rb @@ -8,7 +8,7 @@ module ActiveRecord module SpawnMethods # This is overridden by Associations::CollectionProxy def spawn #:nodoc: - already_in_scope? ? klass.all : clone + already_in_scope?(klass.scope_registry) ? klass.all : clone end # Merges in the conditions from other, if other is an ActiveRecord::Relation. diff --git a/activerecord/lib/active_record/scoping.rb b/activerecord/lib/active_record/scoping.rb index 54eeb952a3..96d3cba49c 100644 --- a/activerecord/lib/active_record/scoping.rb +++ b/activerecord/lib/active_record/scoping.rb @@ -24,19 +24,23 @@ def scope_attributes? end def current_scope(skip_inherited_scope = false) - ScopeRegistry.value_for(:current_scope, self, skip_inherited_scope) + ScopeRegistry.current_scope(self, skip_inherited_scope) end def current_scope=(scope) - ScopeRegistry.set_value_for(:current_scope, self, scope) + ScopeRegistry.set_current_scope(self, scope) end def global_current_scope(skip_inherited_scope = false) - ScopeRegistry.value_for(:global_current_scope, self, skip_inherited_scope) + ScopeRegistry.global_current_scope(self, skip_inherited_scope) end def global_current_scope=(scope) - ScopeRegistry.set_value_for(:global_current_scope, self, scope) + ScopeRegistry.set_global_current_scope(self, scope) + end + + def scope_registry + ScopeRegistry.instance end end @@ -80,34 +84,40 @@ class ScopeRegistry # :nodoc: VALID_SCOPE_TYPES = [:current_scope, :ignore_default_scope, :global_current_scope] def initialize - @registry = Hash.new { |hash, key| hash[key] = {} } + @current_scope = {} + @ignore_default_scope = {} + @global_current_scope = {} end - # Obtains the value for a given +scope_type+ and +model+. - def value_for(scope_type, model, skip_inherited_scope = false) - raise_invalid_scope_type!(scope_type) - return @registry[scope_type][model.name] if skip_inherited_scope - klass = model - base = model.base_class - while klass <= base - value = @registry[scope_type][klass.name] - return value if value - klass = klass.superclass + VALID_SCOPE_TYPES.each do |type| + class_eval <<-eorb, __FILE__, __LINE__ + def #{type}(model, skip_inherited_scope = false) + value_for(@#{type}, model, skip_inherited_scope) end - end - # Sets the +value+ for a given +scope_type+ and +model+. - def set_value_for(scope_type, model, value) - raise_invalid_scope_type!(scope_type) - @registry[scope_type][model.name] = value + def set_#{type}(model, value) + set_value_for(@#{type}, model, value) + end + eorb end private - def raise_invalid_scope_type!(scope_type) - if !VALID_SCOPE_TYPES.include?(scope_type) - raise ArgumentError, "Invalid scope type '#{scope_type}' sent to the registry. Scope types must be included in VALID_SCOPE_TYPES" + # Obtains the value for a given +scope_type+ and +model+. + def value_for(scope_type, model, skip_inherited_scope = false) + return scope_type[model.name] if skip_inherited_scope + klass = model + base = model.base_class + while klass <= base + value = scope_type[klass.name] + return value if value + klass = klass.superclass end end + + # Sets the +value+ for a given +scope_type+ and +model+. + def set_value_for(scope_type, model, value) + scope_type[model.name] = value + end end end end diff --git a/activerecord/lib/active_record/scoping/default.rb b/activerecord/lib/active_record/scoping/default.rb index c748eea5c9..0b81cc461e 100644 --- a/activerecord/lib/active_record/scoping/default.rb +++ b/activerecord/lib/active_record/scoping/default.rb @@ -173,11 +173,11 @@ def execute_scope?(all_queries, default_scope_obj) end def ignore_default_scope? - ScopeRegistry.value_for(:ignore_default_scope, base_class) + ScopeRegistry.ignore_default_scope(base_class) end def ignore_default_scope=(ignore) - ScopeRegistry.set_value_for(:ignore_default_scope, base_class, ignore) + ScopeRegistry.set_ignore_default_scope(base_class, ignore) end # The ignore_default_scope flag is used to prevent an infinite recursion diff --git a/activerecord/test/cases/base_test.rb b/activerecord/test/cases/base_test.rb index 51f9d476f0..beb1a8a8ff 100644 --- a/activerecord/test/cases/base_test.rb +++ b/activerecord/test/cases/base_test.rb @@ -1259,9 +1259,9 @@ def test_current_scope_is_reset UnloadablePost.unloadable klass = UnloadablePost - assert_not_nil ActiveRecord::Scoping::ScopeRegistry.value_for(:current_scope, klass) + assert_not_nil ActiveRecord::Scoping::ScopeRegistry.current_scope(klass) ActiveSupport::Dependencies.remove_unloadable_constants! - assert_nil ActiveRecord::Scoping::ScopeRegistry.value_for(:current_scope, klass) + assert_nil ActiveRecord::Scoping::ScopeRegistry.current_scope(klass) ensure Object.class_eval { remove_const :UnloadablePost } if defined?(UnloadablePost) end diff --git a/activerecord/test/models/post.rb b/activerecord/test/models/post.rb index d55d1e16ae..15f7f55725 100644 --- a/activerecord/test/models/post.rb +++ b/activerecord/test/models/post.rb @@ -324,6 +324,10 @@ class FakeKlass extend ActiveRecord::Delegation::DelegateCache class << self + def scope_registry + ActiveRecord::Scoping::ScopeRegistry.instance + end + def connection Post.connection end