diff --git a/activerecord/CHANGELOG.md b/activerecord/CHANGELOG.md index e30a0a174a..aafdc7e028 100644 --- a/activerecord/CHANGELOG.md +++ b/activerecord/CHANGELOG.md @@ -1,3 +1,12 @@ +* Assign auto populated columns on Active Record record creation + + Changes record creation logic to allow for the `auto_increment` column to be assigned + right after creation regardless of it's relation to model's primary key. + PostgreSQL adapter benefits the most from the change allowing for any number of auto-populated + columns to be assigned on the object immediately after row insertion utilizing the `RETURNING` statement. + + *Nikita Vasilevsky* + * Use the first key in the `shards` hash from `connected_to` for the `default_shard`. Some applications may not want to use `:default` as a shard name in their connection model. Unfortunately Active Record expects there to be a `:default` shard because it must assume a shard to get the right connection from the pool manager. Rather than force applications to manually set this, `connects_to` can infer the default shard name from the hash of shards and will now assume that the first shard is your default. diff --git a/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb b/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb index 25b8d8e24b..b2810112d7 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract/database_statements.rb @@ -145,8 +145,11 @@ def exec_query(sql, name = "SQL", binds = [], prepare: false) # Executes insert +sql+ statement in the context of this connection using # +binds+ as the bind substitutes. +name+ is logged along with # the executed +sql+ statement. - def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil) - sql, binds = sql_for_insert(sql, pk, binds) + # Some adapters support the `returning` keyword argument which allows to control the result of the query: + # `nil` is the default value and maintains default behavior. If an array of column names is passed - + # the result will contain values of the specified columns from the inserted row. + def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil, returning: nil) + sql, binds = sql_for_insert(sql, pk, binds, returning) internal_exec_query(sql, name, binds) end @@ -180,10 +183,14 @@ def explain(arel, binds = [], options = []) # :nodoc: # # If the next id was calculated in advance (as in Oracle), it should be # passed in as +id_value+. - def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = []) + # Some adapters support the `returning` keyword argument which allows defining the return value of the method: + # `nil` is the default value and maintains default behavior. If an array of column names is passed - + # an array of is returned from the method representing values of the specified columns from the inserted row. + def insert(arel, name = nil, pk = nil, id_value = nil, sequence_name = nil, binds = [], returning: nil) sql, binds = to_sql_and_binds(arel, binds) - value = exec_insert(sql, name, binds, pk, sequence_name) - id_value || last_inserted_id(value) + value = exec_insert(sql, name, binds, pk, sequence_name, returning: returning) + return id_value if id_value + returning.nil? ? last_inserted_id(value) : returning_column_values(value) end alias create insert @@ -626,7 +633,7 @@ def select(sql, name = nil, binds = [], prepare: false, async: false) end end - def sql_for_insert(sql, pk, binds) + def sql_for_insert(sql, _pk, binds, _returning) [sql, binds] end @@ -634,6 +641,10 @@ def last_inserted_id(result) single_value_from_rows(result.rows) end + def returning_column_values(result) + [last_inserted_id(result)] + end + def single_value_from_rows(rows) row = rows.first row && row.first diff --git a/activerecord/lib/active_record/connection_adapters/abstract/schema_statements.rb b/activerecord/lib/active_record/connection_adapters/abstract/schema_statements.rb index 4a79b1b1b8..b05eade5d2 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract/schema_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract/schema_statements.rb @@ -106,8 +106,9 @@ def index_exists?(table_name, column_name, **options) # Returns an array of +Column+ objects for the table specified by +table_name+. def columns(table_name) table_name = table_name.to_s - column_definitions(table_name).map do |field| - new_column_from_field(table_name, field) + definitions = column_definitions(table_name) + definitions.map do |field| + new_column_from_field(table_name, field, definitions) end end diff --git a/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb b/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb index 7241e73ced..c930688972 100644 --- a/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/abstract_adapter.rb @@ -580,6 +580,10 @@ def supports_concurrent_connections? true end + def return_value_after_insert?(column) # :nodoc: + column.auto_incremented_by_db? + end + def async_enabled? # :nodoc: supports_concurrent_connections? && !ActiveRecord.async_query_executor.nil? && !pool.async_executor.nil? diff --git a/activerecord/lib/active_record/connection_adapters/column.rb b/activerecord/lib/active_record/connection_adapters/column.rb index bfff81b91b..4c05be416a 100644 --- a/activerecord/lib/active_record/connection_adapters/column.rb +++ b/activerecord/lib/active_record/connection_adapters/column.rb @@ -63,6 +63,15 @@ def encode_with(coder) coder["comment"] = @comment end + # whether the column is auto-populated by the database using a sequence + def auto_incremented_by_db? + false + end + + def auto_populated? + auto_incremented_by_db? || default_function + end + def ==(other) other.is_a?(Column) && name == other.name && diff --git a/activerecord/lib/active_record/connection_adapters/mysql/column.rb b/activerecord/lib/active_record/connection_adapters/mysql/column.rb index c21529b0a8..0d4b022548 100644 --- a/activerecord/lib/active_record/connection_adapters/mysql/column.rb +++ b/activerecord/lib/active_record/connection_adapters/mysql/column.rb @@ -17,6 +17,7 @@ def case_sensitive? def auto_increment? extra == "auto_increment" end + alias_method :auto_incremented_by_db?, :auto_increment? def virtual? /\b(?:VIRTUAL|STORED|PERSISTENT)\b/.match?(extra) diff --git a/activerecord/lib/active_record/connection_adapters/mysql/schema_statements.rb b/activerecord/lib/active_record/connection_adapters/mysql/schema_statements.rb index e51e607b2f..b35f3d7add 100644 --- a/activerecord/lib/active_record/connection_adapters/mysql/schema_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/mysql/schema_statements.rb @@ -175,7 +175,7 @@ def default_type(table_name, field_name) end end - def new_column_from_field(table_name, field) + def new_column_from_field(table_name, field, _definitions) field_name = field.fetch(:Field) type_metadata = fetch_type_metadata(field[:Type], field[:Extra]) default, default_function = field[:Default], nil diff --git a/activerecord/lib/active_record/connection_adapters/postgresql/column.rb b/activerecord/lib/active_record/connection_adapters/postgresql/column.rb index 4ed217dd09..7d31d10042 100644 --- a/activerecord/lib/active_record/connection_adapters/postgresql/column.rb +++ b/activerecord/lib/active_record/connection_adapters/postgresql/column.rb @@ -15,6 +15,7 @@ def initialize(*, serial: nil, generated: nil, **) def serial? @serial end + alias_method :auto_incremented_by_db?, :serial? def virtual? # We assume every generated column is virtual, no matter the concrete type diff --git a/activerecord/lib/active_record/connection_adapters/postgresql/database_statements.rb b/activerecord/lib/active_record/connection_adapters/postgresql/database_statements.rb index 7652cad9dd..1cc01a137f 100644 --- a/activerecord/lib/active_record/connection_adapters/postgresql/database_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/postgresql/database_statements.rb @@ -75,22 +75,23 @@ def exec_delete(sql, name = nil, binds = []) # :nodoc: end alias :exec_update :exec_delete - def sql_for_insert(sql, pk, binds) # :nodoc: + def sql_for_insert(sql, pk, binds, returning) # :nodoc: if pk.nil? # Extract the table from the insert sql. Yuck. table_ref = extract_table_ref_from_insert_sql(sql) pk = primary_key(table_ref) if table_ref end - if pk = suppress_composite_primary_key(pk) - sql = "#{sql} RETURNING #{quote_column_name(pk)}" - end + returning_columns = returning || Array(pk) + + returning_columns_statement = returning_columns.map { |c| quote_column_name(c) }.join(", ") + sql = "#{sql} RETURNING #{returning_columns_statement}" if returning_columns.any? super end private :sql_for_insert - def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil) # :nodoc: + def exec_insert(sql, name = nil, binds = [], pk = nil, sequence_name = nil, returning: nil) # :nodoc: if use_insert_returning? || pk == false super else @@ -172,6 +173,10 @@ def last_insert_id_result(sequence_name) internal_exec_query("SELECT currval(#{quote(sequence_name)})", "SQL") end + def returning_column_values(result) + result.rows.first + end + def suppress_composite_primary_key(pk) pk unless pk.is_a?(Array) end diff --git a/activerecord/lib/active_record/connection_adapters/postgresql/schema_statements.rb b/activerecord/lib/active_record/connection_adapters/postgresql/schema_statements.rb index a93325fe50..906868333a 100644 --- a/activerecord/lib/active_record/connection_adapters/postgresql/schema_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/postgresql/schema_statements.rb @@ -897,7 +897,7 @@ def create_alter_table(name) PostgreSQL::AlterTable.new create_table_definition(name) end - def new_column_from_field(table_name, field) + def new_column_from_field(table_name, field, _definitions) column_name, type, default, notnull, oid, fmod, collation, comment, attgenerated = field type_metadata = fetch_type_metadata(column_name, type, oid.to_i, fmod.to_i) default_value = extract_value_from_default(default) diff --git a/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb b/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb index 8d6406fea9..502d52591f 100644 --- a/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb +++ b/activerecord/lib/active_record/connection_adapters/postgresql_adapter.rb @@ -281,6 +281,10 @@ def index_algorithms { concurrently: "CONCURRENTLY" } end + def return_value_after_insert?(column) # :nodoc: + column.auto_populated? + end + class StatementPool < ConnectionAdapters::StatementPool # :nodoc: def initialize(connection, max) super(max) diff --git a/activerecord/lib/active_record/connection_adapters/sqlite3/column.rb b/activerecord/lib/active_record/connection_adapters/sqlite3/column.rb index 005ee76449..93d3d08685 100644 --- a/activerecord/lib/active_record/connection_adapters/sqlite3/column.rb +++ b/activerecord/lib/active_record/connection_adapters/sqlite3/column.rb @@ -4,15 +4,22 @@ module ActiveRecord module ConnectionAdapters module SQLite3 class Column < ConnectionAdapters::Column # :nodoc: - def initialize(*, auto_increment: nil, **) + attr_reader :rowid + + def initialize(*, auto_increment: nil, rowid: false, **) super @auto_increment = auto_increment + @rowid = rowid end def auto_increment? @auto_increment end + def auto_incremented_by_db? + auto_increment? || rowid + end + def init_with(coder) @auto_increment = coder["auto_increment"] super @@ -33,7 +40,8 @@ def ==(other) def hash Column.hash ^ super.hash ^ - auto_increment?.hash + auto_increment?.hash ^ + rowid.hash end end end diff --git a/activerecord/lib/active_record/connection_adapters/sqlite3/schema_statements.rb b/activerecord/lib/active_record/connection_adapters/sqlite3/schema_statements.rb index 19d1acdc0c..a5b3698cee 100644 --- a/activerecord/lib/active_record/connection_adapters/sqlite3/schema_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/sqlite3/schema_statements.rb @@ -132,12 +132,13 @@ def validate_index_length!(table_name, new_name, internal = false) super unless internal end - def new_column_from_field(table_name, field) + def new_column_from_field(table_name, field, definitions) default = field["dflt_value"] type_metadata = fetch_type_metadata(field["type"]) default_value = extract_value_from_default(default) default_function = extract_default_function(default_value, default) + rowid = is_column_the_rowid?(field, definitions) Column.new( field["name"], @@ -147,9 +148,20 @@ def new_column_from_field(table_name, field) default_function, collation: field["collation"], auto_increment: field["auto_increment"], + rowid: rowid ) end + INTEGER_REGEX = /integer/i + # if a rowid table has a primary key that consists of a single column + # and the declared type of that column is "INTEGER" in any mixture of upper and lower case, + # then the column becomes an alias for the rowid. + def is_column_the_rowid?(field, column_definitions) + return false unless INTEGER_REGEX.match?(field["type"]) && field["pk"] == 1 + # is the primary key a single column? + column_definitions.one? { |c| c["pk"] > 0 } + end + def data_source_sql(name = nil, type: nil) scope = quoted_scope(name, type: type) scope[:type] ||= "'table','view'" diff --git a/activerecord/lib/active_record/connection_adapters/trilogy/database_statements.rb b/activerecord/lib/active_record/connection_adapters/trilogy/database_statements.rb index 874d8b3f60..97b5d8f780 100644 --- a/activerecord/lib/active_record/connection_adapters/trilogy/database_statements.rb +++ b/activerecord/lib/active_record/connection_adapters/trilogy/database_statements.rb @@ -21,7 +21,7 @@ def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: fa ActiveRecord::Result.new(result.fields, result.to_a) end - def exec_insert(sql, name, binds, pk = nil, sequence_name = nil) # :nodoc: + def exec_insert(sql, name, binds, pk = nil, sequence_name = nil, returning: nil) # :nodoc: sql = transform_query(sql) check_if_write_query(sql) mark_transaction_written_if_write(sql) diff --git a/activerecord/lib/active_record/model_schema.rb b/activerecord/lib/active_record/model_schema.rb index 8e4022a08b..32fa2f0c0f 100644 --- a/activerecord/lib/active_record/model_schema.rb +++ b/activerecord/lib/active_record/model_schema.rb @@ -422,6 +422,12 @@ def columns @columns ||= columns_hash.values.freeze end + def _returning_columns_for_insert # :nodoc: + @_returning_columns_for_insert ||= columns.filter_map do |c| + c.name if connection.return_value_after_insert?(c) + end + end + def attribute_types # :nodoc: load_schema @attribute_types ||= Hash.new(Type.default_value) @@ -546,6 +552,7 @@ def initialize_load_schema_monitor end def reload_schema_from_cache(recursive = true) + @_returning_columns_for_insert = nil @arel_table = nil @column_names = nil @symbol_column_to_string_name_hash = nil diff --git a/activerecord/lib/active_record/persistence.rb b/activerecord/lib/active_record/persistence.rb index 462f620dd0..3b534d6f07 100644 --- a/activerecord/lib/active_record/persistence.rb +++ b/activerecord/lib/active_record/persistence.rb @@ -561,7 +561,7 @@ def delete(id_or_array) delete_by(primary_key => id_or_array) end - def _insert_record(values) # :nodoc: + def _insert_record(values, returning) # :nodoc: primary_key = self.primary_key primary_key_value = nil @@ -580,7 +580,10 @@ def _insert_record(values) # :nodoc: im.insert(values.transform_keys { |name| arel_table[name] }) end - connection.insert(im, "#{self} Create", primary_key || false, primary_key_value) + connection.insert( + im, "#{self} Create", primary_key || false, primary_key_value, + returning: returning + ) end def _update_record(values, constraints) # :nodoc: @@ -1235,11 +1238,16 @@ def _update_record(attribute_names = self.attribute_names) def _create_record(attribute_names = self.attribute_names) attribute_names = attributes_for_create(attribute_names) - new_id = self.class._insert_record( - attributes_with_values(attribute_names) + returning_columns = self.class._returning_columns_for_insert + + returning_values = self.class._insert_record( + attributes_with_values(attribute_names), + returning_columns ) - self.id ||= new_id if @primary_key + returning_columns.zip(returning_values).each do |column, value| + _write_attribute(column, value) if !_read_attribute(column) + end if returning_values @new_record = false @previously_new_record = true diff --git a/activerecord/test/cases/adapters/postgresql/uuid_test.rb b/activerecord/test/cases/adapters/postgresql/uuid_test.rb index ce4e8db7cc..6fcad53c9c 100644 --- a/activerecord/test/cases/adapters/postgresql/uuid_test.rb +++ b/activerecord/test/cases/adapters/postgresql/uuid_test.rb @@ -39,6 +39,7 @@ class UUIDType < ActiveRecord::Base end teardown do + UUIDType.reset_column_information drop_table "uuid_data_type" end diff --git a/activerecord/test/cases/adapters/sqlite3/sqlite3_adapter_test.rb b/activerecord/test/cases/adapters/sqlite3/sqlite3_adapter_test.rb index 388ca966b4..ea612100a1 100644 --- a/activerecord/test/cases/adapters/sqlite3/sqlite3_adapter_test.rb +++ b/activerecord/test/cases/adapters/sqlite3/sqlite3_adapter_test.rb @@ -710,6 +710,42 @@ def test_strict_strings_by_default_and_false_in_database_yml end end + def test_rowid_column + with_example_table "id_uppercase INTEGER PRIMARY KEY" do + assert @conn.columns("ex").index_by(&:name)["id_uppercase"].rowid + end + end + + def test_lowercase_rowid_column + with_example_table "id_lowercase integer PRIMARY KEY" do + assert @conn.columns("ex").index_by(&:name)["id_lowercase"].rowid + end + end + + def test_non_integer_column_returns_false_for_rowid + with_example_table "id_int_short int PRIMARY KEY" do + assert_not @conn.columns("ex").index_by(&:name)["id_int_short"].rowid + end + end + + def test_mixed_case_integer_colum_returns_true_for_rowid + with_example_table "id_mixed_case InTeGeR PRIMARY KEY" do + assert @conn.columns("ex").index_by(&:name)["id_mixed_case"].rowid + end + end + + def test_rowid_column_with_autoincrement_returns_true_for_rowid + with_example_table "id_autoincrement integer PRIMARY KEY AUTOINCREMENT" do + assert @conn.columns("ex").index_by(&:name)["id_autoincrement"].rowid + end + end + + def test_integer_cpk_column_returns_false_for_rowid + with_example_table("id integer, shop_id integer, PRIMARY KEY (shop_id, id)", "cpk_table") do + assert_not @conn.columns("cpk_table").any?(&:rowid) + end + end + private def assert_logged(logs) subscriber = SQLSubscriber.new diff --git a/activerecord/test/cases/persistence_test.rb b/activerecord/test/cases/persistence_test.rb index 1f341da7ae..0c0340fa99 100644 --- a/activerecord/test/cases/persistence_test.rb +++ b/activerecord/test/cases/persistence_test.rb @@ -23,11 +23,38 @@ require "models/admin/user" require "models/cpk" require "models/chat_message" +require "models/default" class PersistenceTest < ActiveRecord::TestCase fixtures :topics, :companies, :developers, :accounts, :minimalistics, :authors, :author_addresses, :posts, :minivans, :clothing_items, :cpk_books + def test_populates_non_primary_key_autoincremented_column + topic = TitlePrimaryKeyTopic.create!(title: "title pk topic") + + assert_not_nil topic.attributes["id"] + end + + def test_populates_non_primary_key_autoincremented_column_for_a_cpk_model + order = Cpk::Order.create(shop_id: 111_222) + + _shop_id, order_id = order.id + + assert_not_nil order_id + end + + def test_fills_auto_populated_columns_on_creation + record_with_defaults = Default.create + assert_not_nil record_with_defaults.id + assert_equal "Ruby on Rails", record_with_defaults.ruby_on_rails + assert_not_nil record_with_defaults.rand_number + assert_not_nil record_with_defaults.modified_date + assert_not_nil record_with_defaults.modified_date_function + assert_not_nil record_with_defaults.modified_time + assert_not_nil record_with_defaults.modified_time_without_precision + assert_not_nil record_with_defaults.modified_time_function + end if current_adapter?(:PostgreSQLAdapter) + def test_update_many topic_data = { 1 => { "content" => "1 updated" }, 2 => { "content" => "2 updated" } } updated = Topic.update(topic_data.keys, topic_data.values) diff --git a/activerecord/test/schema/postgresql_specific_schema.rb b/activerecord/test/schema/postgresql_specific_schema.rb index 0c85a137ba..978d7b459c 100644 --- a/activerecord/test/schema/postgresql_specific_schema.rb +++ b/activerecord/test/schema/postgresql_specific_schema.rb @@ -25,6 +25,8 @@ end create_table :defaults, force: true do |t| + t.integer :rand_number, default: -> { "random() * 100" } + t.string :ruby_on_rails, default: -> { "concat('Ruby ', 'on ', 'Rails')" } t.date :modified_date, default: -> { "CURRENT_DATE" } t.date :modified_date_function, default: -> { "now()" } t.date :fixed_date, default: "2004-01-01" diff --git a/activerecord/test/schema/schema.rb b/activerecord/test/schema/schema.rb index baaae8b7d3..ad930a75ba 100644 --- a/activerecord/test/schema/schema.rb +++ b/activerecord/test/schema/schema.rb @@ -259,9 +259,10 @@ t.string :comment end - create_table :cpk_orders, primary_key: [:shop_id, :id], force: true do |t| + # not a composite primary key on the db level to get autoincrement behavior for `id` column + # composite primary key is configured on the model level + create_table :cpk_orders, force: true do |t| t.integer :shop_id - t.integer :id t.string :status end