Assign auto-populated columns on Active Record object creation

This commit extends Active Record creation logic to allow database
auto-populated attributes to be assigned on object creation.

Given a `Post` model represented by the following schema:
```ruby
create_table :posts, id: false do |t|
  t.integer :sequential_number, auto_increment: true
  t.string :title, primary_key: true
  t.string :ruby_on_rails, default: -> { "concat('R', 'o', 'R')" }
end
```
In this schema `title` is used as the primary key, the integer
`sequential_number` column is populated by a sequence, and the
`ruby_on_rails` column has a default function. Creating a `Post`
record should populate the `sequential_number` and `ruby_on_rails`
attributes:

```ruby
new_post = Post.create(title: 'My first post')
new_post.sequential_number # => 1
new_post.ruby_on_rails # => 'RoR'
```

* At the moment the MySQL and SQLite adapters are limited to a single
populated column, which must be the `auto_increment` column, while the
PostgreSQL adapter supports any number of auto-populated columns through
the `RETURNING` clause.
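
A hedged sketch of the adapter difference, reusing the hypothetical `Post` schema above:

```ruby
# MySQL and SQLite assign back only the auto-incremented column; values produced
# by database default functions still require a reload on those adapters.
post = Post.create(title: 'My first post')
post.sequential_number # => 1 on all adapters
post.ruby_on_rails     # => 'RoR' on PostgreSQL; nil until `post.reload` on MySQL/SQLite
```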
Nikita Vasilevsky 2023-04-21 18:52:45 +00:00
parent d026af43c3
commit c92933265e
21 changed files with 173 additions and 26 deletions

@ -1,3 +1,12 @@
* Assign auto-populated columns on Active Record record creation
Changes record creation logic to allow the `auto_increment` column to be assigned
right after creation, regardless of its relation to the model's primary key.
The PostgreSQL adapter benefits the most from this change, allowing any number of auto-populated
columns to be assigned on the object immediately after row insertion by using the `RETURNING` clause.
*Nikita Vasilevsky*
* Use the first key in the `shards` hash from `connected_to` for the `default_shard`.
Some applications may not want to use `:default` as a shard name in their connection model. Unfortunately Active Record expects there to be a `:default` shard because it must assume a shard to get the right connection from the pool manager. Rather than force applications to manually set this, `connects_to` can infer the default shard name from the hash of shards and will now assume that the first shard is your default.
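
A minimal sketch of the inferred default, assuming a hypothetical abstract connection class (`ShardRecord` and the shard/database names are illustrative):

```ruby
class ShardRecord < ApplicationRecord
  self.abstract_class = true

  # No `:default` key is required anymore: the first key, `shard_one`,
  # is inferred as this class's default shard.
  connects_to shards: {
    shard_one: { writing: :primary_shard_one },
    shard_two: { writing: :primary_shard_two }
  }
end
```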

@ -145,8 +145,11 @@ def exec_query(sql, name = "SQL", binds = [], prepare: false)
# Executes insert +sql+ statement in the context of this connection using
# +binds+ as the bind substitutes. +name+ is logged along with
# the executed +sql+ statement.
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil)
sql, binds = sql_for_insert(sql, pk, binds)
# Some adapters support the `returning` keyword argument, which controls the result of the query:
# `nil` is the default value and preserves the existing behavior. If an array of column names is
# passed, the result will contain the values of the specified columns from the inserted row.
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil, returning: nil)
sql, binds = sql_for_insert(sql, pk, binds, returning)
internal_exec_query(sql, name, binds)
end
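
A rough usage sketch for the new keyword argument, assuming a PostgreSQL connection and the hypothetical `posts` table from the commit message (bind handling is simplified):

```ruby
conn  = ActiveRecord::Base.connection
binds = [ActiveRecord::Relation::QueryAttribute.new("title", "My first post", ActiveRecord::Type::String.new)]

result = conn.exec_insert(
  'INSERT INTO "posts" ("title") VALUES ($1)', "Post Create", binds,
  nil, nil, returning: ["sequential_number", "ruby_on_rails"]
)
result.rows.first # => e.g. [1, "RoR"] on adapters that append a RETURNING clause
```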
@ -180,10 +183,14 @@ def explain(arel, binds = [], options = []) # :nodoc:
#
# If the next id was calculated in advance (as in Oracle), it should be
# passed in as +id_value+.
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = [])
# Some adapters support the `returning` keyword argument, which defines the return value of the
# method: `nil` is the default value and preserves the existing behavior. If an array of column
# names is passed, an array containing the values of the specified columns from the inserted row
# is returned.
def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = [], returning: nil)
sql, binds = to_sql_and_binds(arel, binds)
value = exec_insert(sql, name, binds, pk, sequence_name)
id_value || last_inserted_id(value)
value = exec_insert(sql, name, binds, pk, sequence_name, returning: returning)
return id_value if id_value
returning.nil? ? last_inserted_id(value) : returning_column_values(value)
end
alias create insert
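
A sketch of the two return shapes, assuming `im` is an Arel insert manager built elsewhere (hypothetical) and the connection supports `RETURNING`:

```ruby
conn = ActiveRecord::Base.connection

# Default behavior is unchanged: the last inserted id (or `id_value`) is returned.
conn.insert(im, "Post Create", "title")

# With `returning:`, an array of the requested column values is returned instead.
conn.insert(im, "Post Create", nil, nil, nil, [],
            returning: ["sequential_number", "ruby_on_rails"]) # => [1, "RoR"]
```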
@ -626,7 +633,7 @@ def select(sql, name = nil, binds = [], prepare: false, async: false)
end
end
def sql_for_insert(sql, pk, binds)
def sql_for_insert(sql, _pk, binds, _returning)
[sql, binds]
end
@ -634,6 +641,10 @@ def last_inserted_id(result)
single_value_from_rows(result.rows)
end
def returning_column_values(result)
[last_inserted_id(result)]
end
def single_value_from_rows(rows)
row = rows.first
row && row.first

@ -106,8 +106,9 @@ def index_exists?(table_name, column_name, **options)
# Returns an array of +Column+ objects for the table specified by +table_name+.
def columns(table_name)
table_name = table_name.to_s
column_definitions(table_name).map do |field|
new_column_from_field(table_name, field)
definitions = column_definitions(table_name)
definitions.map do |field|
new_column_from_field(table_name, field, definitions)
end
end

@ -580,6 +580,10 @@ def supports_concurrent_connections?
true
end
def return_value_after_insert?(column) # :nodoc:
column.auto_incremented_by_db?
end
def async_enabled? # :nodoc:
supports_concurrent_connections? &&
!ActiveRecord.async_query_executor.nil? && !pool.async_executor.nil?

@ -63,6 +63,15 @@ def encode_with(coder)
coder["comment"] = @comment
end
# Whether the column is auto-populated by the database using a sequence.
def auto_incremented_by_db?
false
end
def auto_populated?
auto_incremented_by_db? || default_function
end
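
A short sketch of how the new predicates read for a concrete column, reusing the hypothetical `posts` schema (the exact `default_function` string depends on the adapter):

```ruby
column = ActiveRecord::Base.connection.columns(:posts).find { |c| c.name == "ruby_on_rails" }

column.default_function        # => e.g. "concat('R', 'o', 'R')"
column.auto_incremented_by_db? # => false
column.auto_populated?         # => true, because a default function is present
```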
def ==(other)
other.is_a?(Column) &&
name == other.name &&

@ -17,6 +17,7 @@ def case_sensitive?
def auto_increment?
extra == "auto_increment"
end
alias_method :auto_incremented_by_db?, :auto_increment?
def virtual?
/\b(?:VIRTUAL|STORED|PERSISTENT)\b/.match?(extra)

@ -175,7 +175,7 @@ def default_type(table_name, field_name)
end
end
def new_column_from_field(table_name, field)
def new_column_from_field(table_name, field, _definitions)
field_name = field.fetch(:Field)
type_metadata = fetch_type_metadata(field[:Type], field[:Extra])
default, default_function = field[:Default], nil

@ -15,6 +15,7 @@ def initialize(*, serial: nil, generated: nil, **)
def serial?
@serial
end
alias_method :auto_incremented_by_db?, :serial?
def virtual?
# We assume every generated column is virtual, no matter the concrete type

@ -75,22 +75,23 @@ def exec_delete(sql, name = nil, binds = []) # :nodoc:
end
alias :exec_update :exec_delete
def sql_for_insert(sql, pk, binds) # :nodoc:
def sql_for_insert(sql, pk, binds, returning) # :nodoc:
if pk.nil?
# Extract the table from the insert sql. Yuck.
table_ref = extract_table_ref_from_insert_sql(sql)
pk = primary_key(table_ref) if table_ref
end
if pk = suppress_composite_primary_key(pk)
sql = "#{sql} RETURNING #{quote_column_name(pk)}"
end
returning_columns = returning || Array(pk)
returning_columns_statement = returning_columns.map { |c| quote_column_name(c) }.join(", ")
sql = "#{sql} RETURNING #{returning_columns_statement}" if returning_columns.any?
super
end
private :sql_for_insert
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil) # :nodoc:
def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil, returning: nil) # :nodoc:
if use_insert_returning? || pk == false
super
else
@ -172,6 +173,10 @@ def last_insert_id_result(sequence_name)
internal_exec_query("SELECT currval(#{quote(sequence_name)})", "SQL")
end
def returning_column_values(result)
result.rows.first
end
def suppress_composite_primary_key(pk)
pk unless pk.is_a?(Array)
end
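
A runnable mini-sketch mirroring the `returning_columns` logic above, showing the shape of the generated statement (illustrative only, not the adapter internals):

```ruby
sql = %(INSERT INTO "posts" ("title") VALUES ($1))
returning_columns = ["sequential_number", "ruby_on_rails"]

sql = "#{sql} RETURNING #{returning_columns.map { |c| %("#{c}") }.join(", ")}" if returning_columns.any?
sql # => INSERT INTO "posts" ("title") VALUES ($1) RETURNING "sequential_number", "ruby_on_rails"
```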

@ -897,7 +897,7 @@ def create_alter_table(name)
PostgreSQL::AlterTable.new create_table_definition(name)
end
def new_column_from_field(table_name, field)
def new_column_from_field(table_name, field, _definitions)
column_name, type, default, notnull, oid, fmod, collation, comment, attgenerated = field
type_metadata = fetch_type_metadata(column_name, type, oid.to_i, fmod.to_i)
default_value = extract_value_from_default(default)

@ -281,6 +281,10 @@ def index_algorithms
{ concurrently: "CONCURRENTLY" }
end
def return_value_after_insert?(column) # :nodoc:
column.auto_populated?
end
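
A sketch of this adapter hook in action, with `Post` reusing the hypothetical schema from the commit message:

```ruby
conn = Post.connection
seq  = Post.columns_hash["sequential_number"] # auto-incremented by the database
ror  = Post.columns_hash["ruby_on_rails"]     # filled by a default function

conn.return_value_after_insert?(seq) # => true wherever the column is database auto-incremented
conn.return_value_after_insert?(ror) # => true only on PostgreSQL, via `auto_populated?`
```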
class StatementPool < ConnectionAdapters::StatementPool # :nodoc:
def initialize(connection, max)
super(max)

@ -4,15 +4,22 @@ module ActiveRecord
module ConnectionAdapters
module SQLite3
class Column < ConnectionAdapters::Column # :nodoc:
def initialize(*, auto_increment: nil, **)
attr_reader :rowid
def initialize(*, auto_increment: nil, rowid: false, **)
super
@auto_increment = auto_increment
@rowid = rowid
end
def auto_increment?
@auto_increment
end
def auto_incremented_by_db?
auto_increment? || rowid
end
def init_with(coder)
@auto_increment = coder["auto_increment"]
super
@ -33,7 +40,8 @@ def ==(other)
def hash
Column.hash ^
super.hash ^
auto_increment?.hash
auto_increment?.hash ^
rowid.hash
end
end
end

@ -132,12 +132,13 @@ def validate_index_length!(table_name, new_name, internal = false)
super unless internal
end
def new_column_from_field(table_name, field)
def new_column_from_field(table_name, field, definitions)
default = field["dflt_value"]
type_metadata = fetch_type_metadata(field["type"])
default_value = extract_value_from_default(default)
default_function = extract_default_function(default_value, default)
rowid = is_column_the_rowid?(field, definitions)
Column.new(
field["name"],
@ -147,9 +148,20 @@ def new_column_from_field(table_name, field)
default_function,
collation: field["collation"],
auto_increment: field["auto_increment"],
rowid: rowid
)
end
INTEGER_REGEX = /integer/i
# if a rowid table has a primary key that consists of a single column
# and the declared type of that column is "INTEGER" in any mixture of upper and lower case,
# then the column becomes an alias for the rowid.
def is_column_the_rowid?(field, column_definitions)
return false unless INTEGER_REGEX.match?(field["type"]) && field["pk"] == 1
# is the primary key a single column?
column_definitions.one? { |c| c["pk"] > 0 }
end
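
A sketch of the detection on SQLite, assuming a table with the default integer primary key (which SQLite treats as an alias for the rowid):

```ruby
ActiveRecord::Schema.define do
  create_table :items, force: true # implicit `id` integer primary key
end

column = ActiveRecord::Base.connection.columns(:items).find { |c| c.name == "id" }
column.rowid                   # => true
column.auto_incremented_by_db? # => true
```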
def data_source_sql(name = nil, type: nil)
scope = quoted_scope(name, type: type)
scope[:type] ||= "'table','view'"

@ -21,7 +21,7 @@ def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: fa
ActiveRecord::Result.new(result.fields, result.to_a)
end
def exec_insert(sql, name, binds, pk = nil, sequence_name = nil) # :nodoc:
def exec_insert(sql, name, binds, pk = nil, sequence_name = nil, returning: nil) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)

@ -422,6 +422,12 @@ def columns
@columns ||= columns_hash.values.freeze
end
def _returning_columns_for_insert # :nodoc:
@_returning_columns_for_insert ||= columns.filter_map do |c|
c.name if connection.return_value_after_insert?(c)
end
end
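
A sketch of what this model-level helper yields, with `Post` reusing the hypothetical schema from the commit message:

```ruby
Post._returning_columns_for_insert
# => ["sequential_number", "ruby_on_rails"] on PostgreSQL
# => ["sequential_number"] on adapters limited to the auto-increment column
```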
def attribute_types # :nodoc:
load_schema
@attribute_types ||= Hash.new(Type.default_value)
@ -546,6 +552,7 @@ def initialize_load_schema_monitor
end
def reload_schema_from_cache(recursive = true)
@_returning_columns_for_insert = nil
@arel_table = nil
@column_names = nil
@symbol_column_to_string_name_hash = nil

@ -561,7 +561,7 @@ def delete(id_or_array)
delete_by(primary_key => id_or_array)
end
def _insert_record(values) # :nodoc:
def _insert_record(values, returning) # :nodoc:
primary_key = self.primary_key
primary_key_value = nil
@ -580,7 +580,10 @@ def _insert_record(values) # :nodoc:
im.insert(values.transform_keys { |name| arel_table[name] })
end
connection.insert(im, "#{self} Create", primary_key || false, primary_key_value)
connection.insert(
im, "#{self} Create", primary_key || false, primary_key_value,
returning: returning
)
end
def _update_record(values, constraints) # :nodoc:
@ -1235,11 +1238,16 @@ def _update_record(attribute_names = self.attribute_names)
def _create_record(attribute_names = self.attribute_names)
attribute_names = attributes_for_create(attribute_names)
new_id = self.class._insert_record(
attributes_with_values(attribute_names)
returning_columns = self.class._returning_columns_for_insert
returning_values = self.class._insert_record(
attributes_with_values(attribute_names),
returning_columns
)
self.id ||= new_id if @primary_key
returning_columns.zip(returning_values).each do |column, value|
_write_attribute(column, value) if !_read_attribute(column)
end if returning_values
@new_record = false
@previously_new_record = true

@ -39,6 +39,7 @@ class UUIDType < ActiveRecord::Base
end
teardown do
UUIDType.reset_column_information
drop_table "uuid_data_type"
end

@ -710,6 +710,42 @@ def test_strict_strings_by_default_and_false_in_database_yml
end
end
def test_rowid_column
with_example_table "id_uppercase INTEGER PRIMARY KEY" do
assert @conn.columns("ex").index_by(&:name)["id_uppercase"].rowid
end
end
def test_lowercase_rowid_column
with_example_table "id_lowercase integer PRIMARY KEY" do
assert @conn.columns("ex").index_by(&:name)["id_lowercase"].rowid
end
end
def test_non_integer_column_returns_false_for_rowid
with_example_table "id_int_short int PRIMARY KEY" do
assert_not @conn.columns("ex").index_by(&:name)["id_int_short"].rowid
end
end
def test_mixed_case_integer_column_returns_true_for_rowid
with_example_table "id_mixed_case InTeGeR PRIMARY KEY" do
assert @conn.columns("ex").index_by(&:name)["id_mixed_case"].rowid
end
end
def test_rowid_column_with_autoincrement_returns_true_for_rowid
with_example_table "id_autoincrement integer PRIMARY KEY AUTOINCREMENT" do
assert @conn.columns("ex").index_by(&:name)["id_autoincrement"].rowid
end
end
def test_integer_cpk_column_returns_false_for_rowid
with_example_table("id integer, shop_id integer, PRIMARY KEY (shop_id, id)", "cpk_table") do
assert_not @conn.columns("cpk_table").any?(&:rowid)
end
end
private
def assert_logged(logs)
subscriber = SQLSubscriber.new

@ -23,11 +23,38 @@
require "models/admin/user"
require "models/cpk"
require "models/chat_message"
require "models/default"
class PersistenceTest < ActiveRecord::TestCase
fixtures :topics, :companies, :developers, :accounts, :minimalistics, :authors, :author_addresses,
:posts, :minivans, :clothing_items, :cpk_books
def test_populates_non_primary_key_autoincremented_column
topic = TitlePrimaryKeyTopic.create!(title: "title pk topic")
assert_not_nil topic.attributes["id"]
end
def test_populates_non_primary_key_autoincremented_column_for_a_cpk_model
order = Cpk::Order.create(shop_id: 111_222)
_shop_id, order_id = order.id
assert_not_nil order_id
end
def test_fills_auto_populated_columns_on_creation
record_with_defaults = Default.create
assert_not_nil record_with_defaults.id
assert_equal "Ruby on Rails", record_with_defaults.ruby_on_rails
assert_not_nil record_with_defaults.rand_number
assert_not_nil record_with_defaults.modified_date
assert_not_nil record_with_defaults.modified_date_function
assert_not_nil record_with_defaults.modified_time
assert_not_nil record_with_defaults.modified_time_without_precision
assert_not_nil record_with_defaults.modified_time_function
end if current_adapter?(:PostgreSQLAdapter)
def test_update_many
topic_data = { 1 => { "content" => "1 updated" }, 2 => { "content" => "2 updated" } }
updated = Topic.update(topic_data.keys, topic_data.values)

@ -25,6 +25,8 @@
end
create_table :defaults, force: true do |t|
t.integer :rand_number, default: -> { "random() * 100" }
t.string :ruby_on_rails, default: -> { "concat('Ruby ', 'on ', 'Rails')" }
t.date :modified_date, default: -> { "CURRENT_DATE" }
t.date :modified_date_function, default: -> { "now()" }
t.date :fixed_date, default: "2004-01-01"

@ -259,9 +259,10 @@
t.string :comment
end
create_table :cpk_orders, primary_key: [:shop_id, :id], force: true do |t|
# Not a composite primary key at the database level, to get autoincrement behavior for the `id` column;
# the composite primary key is configured at the model level.
create_table :cpk_orders, force: true do |t|
t.integer :shop_id
t.integer :id
t.string :status
end