Allow where references association names as joined table alias names

If a table is joined multiple times, those tables are aliased other than
the first one.

It happens easily on self referential associations, and in that case
currently there is no way to work custom attribute (type casting) and
attribute alias resolution for aliased tables in `where` conditions.

To address the issue, it will allow `where` references association names
as table aliases. If association names are referenced in `where`, those
names are used for joined table alias names.

```ruby
class Comment < ActiveRecord::Base
  enum label: [:default, :child]
  has_many :children, class_name: "Comment", foreign_key: :parent_id
end

# ... FROM comments LEFT OUTER JOIN comments children ON ... WHERE children.label = 1
Comment.includes(:children).where("children.label": "child")
```

Fixes #39727.
This commit is contained in:
Ryuta Kamizono 2020-08-24 13:30:41 +09:00
parent 0ed1372c42
commit 999bf3d097
11 changed files with 41 additions and 19 deletions

@ -51,10 +51,13 @@ def initialize(connection, aliases)
@connection = connection
end
def aliased_table_for(arel_table)
if aliases[arel_table.name] == 0
def aliased_table_for(arel_table, table_name = nil)
table_name ||= arel_table.name
if aliases[table_name] == 0
# If it's zero, we can have our table_name
aliases[arel_table.name] = 1
aliases[table_name] = 1
arel_table = arel_table.alias(table_name) if arel_table.name != table_name
else
# Otherwise, we need to use an alias
aliased_name = @connection.table_alias_for(yield)

@ -78,9 +78,14 @@ def reflections
join_root.drop(1).map!(&:reflection)
end
def join_constraints(joins_to_add, alias_tracker)
def join_constraints(joins_to_add, alias_tracker, references)
@alias_tracker = alias_tracker
@joined_tables = {}
@references = {}
references.each do |table_name|
@references[table_name.to_sym] = table_name if table_name.is_a?(String)
end unless references.empty?
joins = make_join_constraints(join_root, join_type)
@ -190,7 +195,9 @@ def make_constraints(parent, child, join_type)
next table, true
end
table = alias_tracker.aliased_table_for(reflection.klass.arel_table) do
table_name = @references[reflection.name.to_sym]
table = alias_tracker.aliased_table_for(reflection.klass.arel_table, table_name) do
name = reflection.alias_candidate(parent.table_name)
root ? name : "#{name}_join"
end

@ -1015,7 +1015,7 @@ def derive_class_name
class PolymorphicReflection < AbstractReflection # :nodoc:
delegate :klass, :scope, :plural_name, :type, :join_primary_key, :join_foreign_key,
:scope_for, to: :@reflection
:name, :scope_for, to: :@reflection
def initialize(reflection, previous_reflection)
@reflection = reflection

@ -887,7 +887,7 @@ def references_eager_loaded_tables?
# always convert table names to downcase as in Oracle quoted table names are in uppercase
joined_tables.map!(&:downcase)
!(references_values - joined_tables).empty?
!(references_values.map(&:to_s) - joined_tables).empty?
end
def tables_in_string(string)

@ -238,8 +238,6 @@ def references(*table_names)
end
def references!(*table_names) # :nodoc:
table_names.map!(&:to_s)
self.references_values |= table_names
self
end
@ -1293,7 +1291,7 @@ def build_joins(join_sources, aliases = nil)
unless association_joins.empty? && stashed_joins.empty?
alias_tracker = alias_tracker(leading_joins + join_nodes, aliases)
join_dependency = construct_join_dependency(association_joins, join_type)
join_sources.concat(join_dependency.join_constraints(stashed_joins, alias_tracker))
join_sources.concat(join_dependency.join_constraints(stashed_joins, alias_tracker, references_values))
end
join_sources.concat(join_nodes) unless join_nodes.empty?

@ -199,6 +199,22 @@ def test_eager_loaded_has_one_association_with_references_does_not_run_additiona
assert_no_queries { authors.map(&:post) }
end
def test_type_cast_in_where_references_association_name
parent = comments(:greetings)
child = parent.children.create!(label: "child", body: "hi", post_id: parent.post_id)
comment = Comment.includes(:children).where("children.label": "child").last
assert_equal parent, comment
assert_equal [child], comment.children
end
def test_attribute_alias_in_where_references_association_name
firm = Firm.includes(:clients).where("clients.new_name": "Summit").last
assert_equal companies(:first_firm), firm
assert_equal [companies(:first_client)], firm.clients
end
def test_calculate_with_string_in_from_and_eager_loading
assert_equal 10, Post.from("authors, posts").eager_load(:comments).where("posts.author_id = authors.id").count
end

@ -113,7 +113,7 @@ def test_using_limitable_reflections_helper
def test_association_with_references
firm = companies(:first_firm)
assert_includes firm.association_with_references.references_values, "foo"
assert_equal [:foo], firm.association_with_references.references_values
end
end

@ -5,7 +5,7 @@
module ActiveRecord
class RelationMutationTest < ActiveRecord::TestCase
(Relation::MULTI_VALUE_METHODS - [:references, :extending, :order, :unscope, :select]).each do |method|
(Relation::MULTI_VALUE_METHODS - [:extending, :order, :unscope, :select]).each do |method|
test "##{method}!" do
assert relation.public_send("#{method}!", :foo).equal?(relation)
assert_equal [:foo], relation.public_send("#{method}_values")
@ -38,11 +38,6 @@ class RelationMutationTest < ActiveRecord::TestCase
end
end
test "#references!" do
assert relation.references!(:foo).equal?(relation)
assert_includes relation.references_values, "foo"
end
test "extending!" do
mod, mod2 = Module.new, Module.new

@ -140,13 +140,13 @@ def test_references_values
relation = Relation.new(FakeKlass)
assert_equal [], relation.references_values
relation = relation.references(:foo).references(:omg, :lol)
assert_equal ["foo", "omg", "lol"], relation.references_values
assert_equal [:foo, :omg, :lol], relation.references_values
end
def test_references_values_dont_duplicate
relation = Relation.new(FakeKlass)
relation = relation.references(:foo).references(:foo)
assert_equal ["foo"], relation.references_values
assert_equal [:foo], relation.references_values
end
test "merging a hash into a relation" do

@ -23,6 +23,8 @@ class Comment < ActiveRecord::Base
has_many :children, class_name: "Comment", foreign_key: :parent_id
belongs_to :parent, class_name: "Comment", counter_cache: :children_count
enum label: [:default, :child]
class ::OopsError < RuntimeError; end
module OopsExtension

@ -210,6 +210,7 @@
t.text :body, null: false
end
t.string :type
t.integer :label, default: 0
t.integer :tags_count, default: 0
t.integer :children_count, default: 0
t.integer :parent_id