Assign auto populated columns on Active Record object creation
This commit extends Active Record creation logic to allow for a database auto-populated attributes to be assigned on object creation. Given a `Post` model represented by the following schema: ```ruby create_table :posts, id: false do |t| t.integer :sequential_number, auto_increment: true t.string :title, primary_key: true t.string :ruby_on_rails, default: -> { "concat('R', 'o', 'R')" } end ``` where `title` is being used as a primary key, the table has an integer `sequential_number` column populated by a sequence and `ruby_on_rails` column has a default function - creation of `Post` records should populate the `sequential_number` and `ruby_on_rails` attributes: ```ruby new_post = Post.create(title: 'My first post') new_post.sequential_number # => 1 new_post.ruby_on_rails # => 'RoR' ``` * At this moment MySQL and SQLite adapters are limited to only one column being populated and the column must be the `auto_increment` while PostgreSQL adapter supports any number of auto-populated columns through `RETURNING` statement.
This commit is contained in:
parent
d026af43c3
commit
c92933265e
@ -1,3 +1,12 @@
|
||||
* Assign auto populated columns on Active Record record creation
|
||||
|
||||
Changes record creation logic to allow for the `auto_increment` column to be assigned
|
||||
right after creation regardless of it's relation to model's primary key.
|
||||
PostgreSQL adapter benefits the most from the change allowing for any number of auto-populated
|
||||
columns to be assigned on the object immediately after row insertion utilizing the `RETURNING` statement.
|
||||
|
||||
*Nikita Vasilevsky*
|
||||
|
||||
* Use the first key in the `shards` hash from `connected_to` for the `default_shard`.
|
||||
|
||||
Some applications may not want to use `:default` as a shard name in their connection model. Unfortunately Active Record expects there to be a `:default` shard because it must assume a shard to get the right connection from the pool manager. Rather than force applications to manually set this, `connects_to` can infer the default shard name from the hash of shards and will now assume that the first shard is your default.
|
||||
|
@ -145,8 +145,11 @@ def exec_query(sql, name = "SQL", binds = [], prepare: false)
|
||||
# Executes insert +sql+ statement in the context of this connection using
|
||||
# +binds+ as the bind substitutes. +name+ is logged along with
|
||||
# the executed +sql+ statement.
|
||||
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil)
|
||||
sql, binds = sql_for_insert(sql, pk, binds)
|
||||
# Some adapters support the `returning` keyword argument which allows to control the result of the query:
|
||||
# `nil` is the default value and maintains default behavior. If an array of column names is passed -
|
||||
# the result will contain values of the specified columns from the inserted row.
|
||||
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil, returning: nil)
|
||||
sql, binds = sql_for_insert(sql, pk, binds, returning)
|
||||
internal_exec_query(sql, name, binds)
|
||||
end
|
||||
|
||||
@ -180,10 +183,14 @@ def explain(arel, binds = [], options = []) # :nodoc:
|
||||
#
|
||||
# If the next id was calculated in advance (as in Oracle), it should be
|
||||
# passed in as +id_value+.
|
||||
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = [])
|
||||
# Some adapters support the `returning` keyword argument which allows defining the return value of the method:
|
||||
# `nil` is the default value and maintains default behavior. If an array of column names is passed -
|
||||
# an array of is returned from the method representing values of the specified columns from the inserted row.
|
||||
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = [], returning: nil)
|
||||
sql, binds = to_sql_and_binds(arel, binds)
|
||||
value = exec_insert(sql, name, binds, pk, sequence_name)
|
||||
id_value || last_inserted_id(value)
|
||||
value = exec_insert(sql, name, binds, pk, sequence_name, returning: returning)
|
||||
return id_value if id_value
|
||||
returning.nil? ? last_inserted_id(value) : returning_column_values(value)
|
||||
end
|
||||
alias create insert
|
||||
|
||||
@ -626,7 +633,7 @@ def select(sql, name = nil, binds = [], prepare: false, async: false)
|
||||
end
|
||||
end
|
||||
|
||||
def sql_for_insert(sql, pk, binds)
|
||||
def sql_for_insert(sql, _pk, binds, _returning)
|
||||
[sql, binds]
|
||||
end
|
||||
|
||||
@ -634,6 +641,10 @@ def last_inserted_id(result)
|
||||
single_value_from_rows(result.rows)
|
||||
end
|
||||
|
||||
def returning_column_values(result)
|
||||
[last_inserted_id(result)]
|
||||
end
|
||||
|
||||
def single_value_from_rows(rows)
|
||||
row = rows.first
|
||||
row && row.first
|
||||
|
@ -106,8 +106,9 @@ def index_exists?(table_name, column_name, **options)
|
||||
# Returns an array of +Column+ objects for the table specified by +table_name+.
|
||||
def columns(table_name)
|
||||
table_name = table_name.to_s
|
||||
column_definitions(table_name).map do |field|
|
||||
new_column_from_field(table_name, field)
|
||||
definitions = column_definitions(table_name)
|
||||
definitions.map do |field|
|
||||
new_column_from_field(table_name, field, definitions)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -580,6 +580,10 @@ def supports_concurrent_connections?
|
||||
true
|
||||
end
|
||||
|
||||
def return_value_after_insert?(column) # :nodoc:
|
||||
column.auto_incremented_by_db?
|
||||
end
|
||||
|
||||
def async_enabled? # :nodoc:
|
||||
supports_concurrent_connections? &&
|
||||
!ActiveRecord.async_query_executor.nil? && !pool.async_executor.nil?
|
||||
|
@ -63,6 +63,15 @@ def encode_with(coder)
|
||||
coder["comment"] = @comment
|
||||
end
|
||||
|
||||
# whether the column is auto-populated by the database using a sequence
|
||||
def auto_incremented_by_db?
|
||||
false
|
||||
end
|
||||
|
||||
def auto_populated?
|
||||
auto_incremented_by_db? || default_function
|
||||
end
|
||||
|
||||
def ==(other)
|
||||
other.is_a?(Column) &&
|
||||
name == other.name &&
|
||||
|
@ -17,6 +17,7 @@ def case_sensitive?
|
||||
def auto_increment?
|
||||
extra == "auto_increment"
|
||||
end
|
||||
alias_method :auto_incremented_by_db?, :auto_increment?
|
||||
|
||||
def virtual?
|
||||
/\b(?:VIRTUAL|STORED|PERSISTENT)\b/.match?(extra)
|
||||
|
@ -175,7 +175,7 @@ def default_type(table_name, field_name)
|
||||
end
|
||||
end
|
||||
|
||||
def new_column_from_field(table_name, field)
|
||||
def new_column_from_field(table_name, field, _definitions)
|
||||
field_name = field.fetch(:Field)
|
||||
type_metadata = fetch_type_metadata(field[:Type], field[:Extra])
|
||||
default, default_function = field[:Default], nil
|
||||
|
@ -15,6 +15,7 @@ def initialize(*, serial: nil, generated: nil, **)
|
||||
def serial?
|
||||
@serial
|
||||
end
|
||||
alias_method :auto_incremented_by_db?, :serial?
|
||||
|
||||
def virtual?
|
||||
# We assume every generated column is virtual, no matter the concrete type
|
||||
|
@ -75,22 +75,23 @@ def exec_delete(sql, name = nil, binds = []) # :nodoc:
|
||||
end
|
||||
alias :exec_update :exec_delete
|
||||
|
||||
def sql_for_insert(sql, pk, binds) # :nodoc:
|
||||
def sql_for_insert(sql, pk, binds, returning) # :nodoc:
|
||||
if pk.nil?
|
||||
# Extract the table from the insert sql. Yuck.
|
||||
table_ref = extract_table_ref_from_insert_sql(sql)
|
||||
pk = primary_key(table_ref) if table_ref
|
||||
end
|
||||
|
||||
if pk = suppress_composite_primary_key(pk)
|
||||
sql = "#{sql} RETURNING #{quote_column_name(pk)}"
|
||||
end
|
||||
returning_columns = returning || Array(pk)
|
||||
|
||||
returning_columns_statement = returning_columns.map { |c| quote_column_name(c) }.join(", ")
|
||||
sql = "#{sql} RETURNING #{returning_columns_statement}" if returning_columns.any?
|
||||
|
||||
super
|
||||
end
|
||||
private :sql_for_insert
|
||||
|
||||
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil) # :nodoc:
|
||||
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil, returning: nil) # :nodoc:
|
||||
if use_insert_returning? || pk == false
|
||||
super
|
||||
else
|
||||
@ -172,6 +173,10 @@ def last_insert_id_result(sequence_name)
|
||||
internal_exec_query("SELECT currval(#{quote(sequence_name)})", "SQL")
|
||||
end
|
||||
|
||||
def returning_column_values(result)
|
||||
result.rows.first
|
||||
end
|
||||
|
||||
def suppress_composite_primary_key(pk)
|
||||
pk unless pk.is_a?(Array)
|
||||
end
|
||||
|
@ -897,7 +897,7 @@ def create_alter_table(name)
|
||||
PostgreSQL::AlterTable.new create_table_definition(name)
|
||||
end
|
||||
|
||||
def new_column_from_field(table_name, field)
|
||||
def new_column_from_field(table_name, field, _definitions)
|
||||
column_name, type, default, notnull, oid, fmod, collation, comment, attgenerated = field
|
||||
type_metadata = fetch_type_metadata(column_name, type, oid.to_i, fmod.to_i)
|
||||
default_value = extract_value_from_default(default)
|
||||
|
@ -281,6 +281,10 @@ def index_algorithms
|
||||
{ concurrently: "CONCURRENTLY" }
|
||||
end
|
||||
|
||||
def return_value_after_insert?(column) # :nodoc:
|
||||
column.auto_populated?
|
||||
end
|
||||
|
||||
class StatementPool < ConnectionAdapters::StatementPool # :nodoc:
|
||||
def initialize(connection, max)
|
||||
super(max)
|
||||
|
@ -4,15 +4,22 @@ module ActiveRecord
|
||||
module ConnectionAdapters
|
||||
module SQLite3
|
||||
class Column < ConnectionAdapters::Column # :nodoc:
|
||||
def initialize(*, auto_increment: nil, **)
|
||||
attr_reader :rowid
|
||||
|
||||
def initialize(*, auto_increment: nil, rowid: false, **)
|
||||
super
|
||||
@auto_increment = auto_increment
|
||||
@rowid = rowid
|
||||
end
|
||||
|
||||
def auto_increment?
|
||||
@auto_increment
|
||||
end
|
||||
|
||||
def auto_incremented_by_db?
|
||||
auto_increment? || rowid
|
||||
end
|
||||
|
||||
def init_with(coder)
|
||||
@auto_increment = coder["auto_increment"]
|
||||
super
|
||||
@ -33,7 +40,8 @@ def ==(other)
|
||||
def hash
|
||||
Column.hash ^
|
||||
super.hash ^
|
||||
auto_increment?.hash
|
||||
auto_increment?.hash ^
|
||||
rowid.hash
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -132,12 +132,13 @@ def validate_index_length!(table_name, new_name, internal = false)
|
||||
super unless internal
|
||||
end
|
||||
|
||||
def new_column_from_field(table_name, field)
|
||||
def new_column_from_field(table_name, field, definitions)
|
||||
default = field["dflt_value"]
|
||||
|
||||
type_metadata = fetch_type_metadata(field["type"])
|
||||
default_value = extract_value_from_default(default)
|
||||
default_function = extract_default_function(default_value, default)
|
||||
rowid = is_column_the_rowid?(field, definitions)
|
||||
|
||||
Column.new(
|
||||
field["name"],
|
||||
@ -147,9 +148,20 @@ def new_column_from_field(table_name, field)
|
||||
default_function,
|
||||
collation: field["collation"],
|
||||
auto_increment: field["auto_increment"],
|
||||
rowid: rowid
|
||||
)
|
||||
end
|
||||
|
||||
INTEGER_REGEX = /integer/i
|
||||
# if a rowid table has a primary key that consists of a single column
|
||||
# and the declared type of that column is "INTEGER" in any mixture of upper and lower case,
|
||||
# then the column becomes an alias for the rowid.
|
||||
def is_column_the_rowid?(field, column_definitions)
|
||||
return false unless INTEGER_REGEX.match?(field["type"]) && field["pk"] == 1
|
||||
# is the primary key a single column?
|
||||
column_definitions.one? { |c| c["pk"] > 0 }
|
||||
end
|
||||
|
||||
def data_source_sql(name = nil, type: nil)
|
||||
scope = quoted_scope(name, type: type)
|
||||
scope[:type] ||= "'table','view'"
|
||||
|
@ -21,7 +21,7 @@ def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: fa
|
||||
ActiveRecord::Result.new(result.fields, result.to_a)
|
||||
end
|
||||
|
||||
def exec_insert(sql, name, binds, pk = nil, sequence_name = nil) # :nodoc:
|
||||
def exec_insert(sql, name, binds, pk = nil, sequence_name = nil, returning: nil) # :nodoc:
|
||||
sql = transform_query(sql)
|
||||
check_if_write_query(sql)
|
||||
mark_transaction_written_if_write(sql)
|
||||
|
@ -422,6 +422,12 @@ def columns
|
||||
@columns ||= columns_hash.values.freeze
|
||||
end
|
||||
|
||||
def _returning_columns_for_insert # :nodoc:
|
||||
@_returning_columns_for_insert ||= columns.filter_map do |c|
|
||||
c.name if connection.return_value_after_insert?(c)
|
||||
end
|
||||
end
|
||||
|
||||
def attribute_types # :nodoc:
|
||||
load_schema
|
||||
@attribute_types ||= Hash.new(Type.default_value)
|
||||
@ -546,6 +552,7 @@ def initialize_load_schema_monitor
|
||||
end
|
||||
|
||||
def reload_schema_from_cache(recursive = true)
|
||||
@_returning_columns_for_insert = nil
|
||||
@arel_table = nil
|
||||
@column_names = nil
|
||||
@symbol_column_to_string_name_hash = nil
|
||||
|
@ -561,7 +561,7 @@ def delete(id_or_array)
|
||||
delete_by(primary_key => id_or_array)
|
||||
end
|
||||
|
||||
def _insert_record(values) # :nodoc:
|
||||
def _insert_record(values, returning) # :nodoc:
|
||||
primary_key = self.primary_key
|
||||
primary_key_value = nil
|
||||
|
||||
@ -580,7 +580,10 @@ def _insert_record(values) # :nodoc:
|
||||
im.insert(values.transform_keys { |name| arel_table[name] })
|
||||
end
|
||||
|
||||
connection.insert(im, "#{self} Create", primary_key || false, primary_key_value)
|
||||
connection.insert(
|
||||
im, "#{self} Create", primary_key || false, primary_key_value,
|
||||
returning: returning
|
||||
)
|
||||
end
|
||||
|
||||
def _update_record(values, constraints) # :nodoc:
|
||||
@ -1235,11 +1238,16 @@ def _update_record(attribute_names = self.attribute_names)
|
||||
def _create_record(attribute_names = self.attribute_names)
|
||||
attribute_names = attributes_for_create(attribute_names)
|
||||
|
||||
new_id = self.class._insert_record(
|
||||
attributes_with_values(attribute_names)
|
||||
returning_columns = self.class._returning_columns_for_insert
|
||||
|
||||
returning_values = self.class._insert_record(
|
||||
attributes_with_values(attribute_names),
|
||||
returning_columns
|
||||
)
|
||||
|
||||
self.id ||= new_id if @primary_key
|
||||
returning_columns.zip(returning_values).each do |column, value|
|
||||
_write_attribute(column, value) if !_read_attribute(column)
|
||||
end if returning_values
|
||||
|
||||
@new_record = false
|
||||
@previously_new_record = true
|
||||
|
@ -39,6 +39,7 @@ class UUIDType < ActiveRecord::Base
|
||||
end
|
||||
|
||||
teardown do
|
||||
UUIDType.reset_column_information
|
||||
drop_table "uuid_data_type"
|
||||
end
|
||||
|
||||
|
@ -710,6 +710,42 @@ def test_strict_strings_by_default_and_false_in_database_yml
|
||||
end
|
||||
end
|
||||
|
||||
def test_rowid_column
|
||||
with_example_table "id_uppercase INTEGER PRIMARY KEY" do
|
||||
assert @conn.columns("ex").index_by(&:name)["id_uppercase"].rowid
|
||||
end
|
||||
end
|
||||
|
||||
def test_lowercase_rowid_column
|
||||
with_example_table "id_lowercase integer PRIMARY KEY" do
|
||||
assert @conn.columns("ex").index_by(&:name)["id_lowercase"].rowid
|
||||
end
|
||||
end
|
||||
|
||||
def test_non_integer_column_returns_false_for_rowid
|
||||
with_example_table "id_int_short int PRIMARY KEY" do
|
||||
assert_not @conn.columns("ex").index_by(&:name)["id_int_short"].rowid
|
||||
end
|
||||
end
|
||||
|
||||
def test_mixed_case_integer_colum_returns_true_for_rowid
|
||||
with_example_table "id_mixed_case InTeGeR PRIMARY KEY" do
|
||||
assert @conn.columns("ex").index_by(&:name)["id_mixed_case"].rowid
|
||||
end
|
||||
end
|
||||
|
||||
def test_rowid_column_with_autoincrement_returns_true_for_rowid
|
||||
with_example_table "id_autoincrement integer PRIMARY KEY AUTOINCREMENT" do
|
||||
assert @conn.columns("ex").index_by(&:name)["id_autoincrement"].rowid
|
||||
end
|
||||
end
|
||||
|
||||
def test_integer_cpk_column_returns_false_for_rowid
|
||||
with_example_table("id integer, shop_id integer, PRIMARY KEY (shop_id, id)", "cpk_table") do
|
||||
assert_not @conn.columns("cpk_table").any?(&:rowid)
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
def assert_logged(logs)
|
||||
subscriber = SQLSubscriber.new
|
||||
|
@ -23,11 +23,38 @@
|
||||
require "models/admin/user"
|
||||
require "models/cpk"
|
||||
require "models/chat_message"
|
||||
require "models/default"
|
||||
|
||||
class PersistenceTest < ActiveRecord::TestCase
|
||||
fixtures :topics, :companies, :developers, :accounts, :minimalistics, :authors, :author_addresses,
|
||||
:posts, :minivans, :clothing_items, :cpk_books
|
||||
|
||||
def test_populates_non_primary_key_autoincremented_column
|
||||
topic = TitlePrimaryKeyTopic.create!(title: "title pk topic")
|
||||
|
||||
assert_not_nil topic.attributes["id"]
|
||||
end
|
||||
|
||||
def test_populates_non_primary_key_autoincremented_column_for_a_cpk_model
|
||||
order = Cpk::Order.create(shop_id: 111_222)
|
||||
|
||||
_shop_id, order_id = order.id
|
||||
|
||||
assert_not_nil order_id
|
||||
end
|
||||
|
||||
def test_fills_auto_populated_columns_on_creation
|
||||
record_with_defaults = Default.create
|
||||
assert_not_nil record_with_defaults.id
|
||||
assert_equal "Ruby on Rails", record_with_defaults.ruby_on_rails
|
||||
assert_not_nil record_with_defaults.rand_number
|
||||
assert_not_nil record_with_defaults.modified_date
|
||||
assert_not_nil record_with_defaults.modified_date_function
|
||||
assert_not_nil record_with_defaults.modified_time
|
||||
assert_not_nil record_with_defaults.modified_time_without_precision
|
||||
assert_not_nil record_with_defaults.modified_time_function
|
||||
end if current_adapter?(:PostgreSQLAdapter)
|
||||
|
||||
def test_update_many
|
||||
topic_data = { 1 => { "content" => "1 updated" }, 2 => { "content" => "2 updated" } }
|
||||
updated = Topic.update(topic_data.keys, topic_data.values)
|
||||
|
@ -25,6 +25,8 @@
|
||||
end
|
||||
|
||||
create_table :defaults, force: true do |t|
|
||||
t.integer :rand_number, default: -> { "random() * 100" }
|
||||
t.string :ruby_on_rails, default: -> { "concat('Ruby ', 'on ', 'Rails')" }
|
||||
t.date :modified_date, default: -> { "CURRENT_DATE" }
|
||||
t.date :modified_date_function, default: -> { "now()" }
|
||||
t.date :fixed_date, default: "2004-01-01"
|
||||
|
@ -259,9 +259,10 @@
|
||||
t.string :comment
|
||||
end
|
||||
|
||||
create_table :cpk_orders, primary_key: [:shop_id, :id], force: true do |t|
|
||||
# not a composite primary key on the db level to get autoincrement behavior for `id` column
|
||||
# composite primary key is configured on the model level
|
||||
create_table :cpk_orders, force: true do |t|
|
||||
t.integer :shop_id
|
||||
t.integer :id
|
||||
t.string :status
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user