Refactor building the update manager

This commit is contained in:
Jon Leighton 2011-08-10 00:06:30 +01:00
parent 43b99f290a
commit fe0ec85541
4 changed files with 13 additions and 16 deletions

@ -310,12 +310,10 @@ def sanitize_limit(limit)
# on mysql (even when aliasing the tables), but mysql allows using JOIN directly in
# an UPDATE statement, so in the mysql adapters we redefine this to do that.
def join_to_update(update, select) #:nodoc:
subselect = select.clone
subselect.ast.cores.last.projections = [update.ast.key]
subselect = select.ast.clone
subselect.cores.last.projections = [update.ast.key]
update.ast.limit = nil
update.ast.orders = []
update.wheres = [update.ast.key.in(subselect)]
update.where update.ast.key.in(subselect)
end
protected

@ -594,11 +594,10 @@ def join_to_update(update, select) #:nodoc:
subselect = Arel::SelectManager.new(select.engine, subsubselect)
subselect.project(Arel::Table.new('__active_record_temp')[update.ast.key.name])
update.ast.limit = nil
update.ast.orders = []
update.wheres = [update.ast.key.in(subselect)]
update.where update.ast.key.in(subselect)
else
update.table select.ast.cores.last.source
update.wheres = select.constraints
end
end

@ -508,11 +508,10 @@ def join_to_update(update, select) #:nodoc:
subselect = Arel::SelectManager.new(select.engine, subsubselect)
subselect.project(Arel::Table.new('__active_record_temp')[update.ast.key.name])
update.ast.limit = nil
update.ast.orders = []
update.wheres = [update.ast.key.in(subselect)]
update.where update.ast.key.in(subselect)
else
update.table select.ast.cores.last.source
update.wheres = select.constraints
end
end

@ -216,17 +216,18 @@ def update_all(updates, conditions = nil, options = {})
if conditions || options.present?
where(conditions).apply_finder_options(options.slice(:limit, :order)).update_all(updates)
else
stmt = arel.compile_update(Arel.sql(@klass.send(:sanitize_sql_for_assignment, updates)))
stmt = Arel::UpdateManager.new(arel.engine)
stmt.set Arel.sql(@klass.send(:sanitize_sql_for_assignment, updates))
stmt.table(table)
stmt.key = table[primary_key]
if joins_values.any?
@klass.connection.join_to_update(stmt, arel)
else
if limit = arel.limit
stmt.take limit
end
stmt.take(arel.limit)
stmt.order(*arel.orders)
stmt.wheres = arel.constraints
end
@klass.connection.update stmt, 'SQL', bind_values