diff --git a/activerecord/lib/active_record/associations/association_collection.rb b/activerecord/lib/active_record/associations/association_collection.rb index ec7af7913d..552c1ed06e 100644 --- a/activerecord/lib/active_record/associations/association_collection.rb +++ b/activerecord/lib/active_record/associations/association_collection.rb @@ -163,11 +163,6 @@ def find_target end private - # Array#flatten has problems with recursive arrays. Going one level deeper solves the majority of the problems. - def flatten_deeper(array) - array.collect { |element| element.respond_to?(:flatten) ? element.flatten : element }.flatten - end - def callback(method, record) callbacks_for(method).each do |callback| case callback diff --git a/activerecord/lib/active_record/associations/association_proxy.rb b/activerecord/lib/active_record/associations/association_proxy.rb index c093e40ca0..1d313fc537 100644 --- a/activerecord/lib/active_record/associations/association_proxy.rb +++ b/activerecord/lib/active_record/associations/association_proxy.rb @@ -146,6 +146,11 @@ def raise_on_type_mismatch(record) raise ActiveRecord::AssociationTypeMismatch, "#{@reflection.class_name} expected, got #{record.class}" end end + + # Array#flatten has problems with recursive arrays. Going one level deeper solves the majority of the problems. + def flatten_deeper(array) + array.collect { |element| element.respond_to?(:flatten) ? element.flatten : element }.flatten + end end end end diff --git a/activerecord/lib/active_record/associations/has_many_through_association.rb b/activerecord/lib/active_record/associations/has_many_through_association.rb index b5feeff1d2..0257b1857d 100644 --- a/activerecord/lib/active_record/associations/has_many_through_association.rb +++ b/activerecord/lib/active_record/associations/has_many_through_association.rb @@ -44,18 +44,26 @@ def reset # must have ids in order to create records associating them, so this # will raise ActiveRecord::HasManyThroughCantAssociateNewRecords if # either is a new record. Calls create! so you can rescue errors. - def <<(*args) - return if args.empty? + # + # The :before_add and :after_add callbacks are not yet supported. + def <<(*records) + return if records.empty? through = @reflection.through_reflection raise ActiveRecord::HasManyThroughCantAssociateNewRecords.new(@owner, through) if @owner.new_record? + load_target + klass = through.klass klass.transaction do - args.each do |associate| + flatten_deeper(records).each do |associate| + raise_on_type_mismatch(associate) raise ActiveRecord::HasManyThroughCantAssociateNewRecords.new(@owner, through) unless associate.respond_to?(:new_record?) && !associate.new_record? - klass.with_scope(:create => construct_join_attributes(associate)) { klass.create! } + + @target << klass.with_scope(:create => construct_join_attributes(associate)) { klass.create! } end end + + self end [:push, :concat].each { |method| alias_method method, :<< } diff --git a/activerecord/test/associations_join_model_test.rb b/activerecord/test/associations_join_model_test.rb index 0f6a5e76c2..1a0882fcfb 100644 --- a/activerecord/test/associations_join_model_test.rb +++ b/activerecord/test/associations_join_model_test.rb @@ -373,15 +373,28 @@ def test_create_associate_when_adding_to_has_many_through count = posts(:thinking).tags.count push = Tag.create!(:name => 'pushme') assert_nothing_raised { posts(:thinking).tags << push } + assert_equal(count + 1, posts(:thinking).tags.size) assert_equal(count + 1, posts(:thinking).tags(true).size) assert_nothing_raised { posts(:thinking).tags.create!(:name => 'foo') } + assert_equal(count + 2, posts(:thinking).tags.size) assert_equal(count + 2, posts(:thinking).tags(true).size) assert_nothing_raised { posts(:thinking).tags.concat(Tag.create!(:name => 'abc'), Tag.create!(:name => 'def')) } + assert_equal(count + 4, posts(:thinking).tags.size) assert_equal(count + 4, posts(:thinking).tags(true).size) end + def test_adding_junk_to_has_many_through_should_raise_type_mismatch + assert_raise(ActiveRecord::AssociationTypeMismatch) { posts(:thinking).tags << "Uhh what now?" } + end + + def test_adding_to_has_many_through_should_return_self + tags = posts(:thinking).tags + assert_equal tags, posts(:thinking).tags.push(tags(:general)) + end + + def test_has_many_through_sum_uses_calculations assert_nothing_raised { authors(:david).comments.sum(:post_id) } end