Merge pull request #48112 from adrianna-chang-shopify/ac-shared-mysql-db-statements

Clean up shared DB statements code between Mysql2 and Trilogy
This commit is contained in:
Eileen M. Uchitelle 2023-05-03 11:52:07 -05:00 committed by GitHub
commit 5947ec2d06
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 208 additions and 271 deletions

@ -3,6 +3,7 @@
require "active_record/connection_adapters/abstract_adapter"
require "active_record/connection_adapters/statement_pool"
require "active_record/connection_adapters/mysql/column"
require "active_record/connection_adapters/mysql/database_statements"
require "active_record/connection_adapters/mysql/explain_pretty_printer"
require "active_record/connection_adapters/mysql/quoting"
require "active_record/connection_adapters/mysql/schema_creation"
@ -14,6 +15,7 @@
module ActiveRecord
module ConnectionAdapters
class AbstractMysqlAdapter < AbstractAdapter
include MySQL::DatabaseStatements
include MySQL::Quoting
include MySQL::SchemaStatements

@ -4,31 +4,26 @@ module ActiveRecord
module ConnectionAdapters
module MySQL
module DatabaseStatements
# Returns an ActiveRecord::Result instance.
def select_all(*, **) # :nodoc:
result = nil
with_raw_connection do |conn|
result = if ExplainRegistry.collect? && prepared_statements
unprepared_statement { super }
else
super
end
conn.abandon_results!
end
result
end
READ_QUERY = ActiveRecord::ConnectionAdapters::AbstractAdapter.build_read_query_regexp(
READ_QUERY = AbstractAdapter.build_read_query_regexp(
:desc, :describe, :set, :show, :use
) # :nodoc:
private_constant :READ_QUERY
# https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_current-timestamp
# https://dev.mysql.com/doc/refman/5.7/en/date-and-time-type-syntax.html
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP(6)").freeze # :nodoc:
private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP
def write_query?(sql) # :nodoc:
!READ_QUERY.match?(sql)
rescue ArgumentError # Invalid encoding
!READ_QUERY.match?(sql.b)
end
def high_precision_current_timestamp
HIGH_PRECISION_CURRENT_TIMESTAMP
end
def explain(arel, binds = [], options = [])
sql = build_explain_clause(options) + " " + to_sql(arel, binds)
start = Process.clock_gettime(Process::CLOCK_MONOTONIC)
@ -38,47 +33,6 @@ def explain(arel, binds = [], options = [])
MySQL::ExplainPrettyPrinter.new.pp(result, elapsed)
end
def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false) # :nodoc:
if without_prepared_statement?(binds)
execute_and_free(sql, name, async: async) do |result|
if result
build_result(columns: result.fields, rows: result.to_a)
else
build_result(columns: [], rows: [])
end
end
else
exec_stmt_and_free(sql, name, binds, cache_stmt: prepare, async: async) do |_, result|
if result
build_result(columns: result.fields, rows: result.to_a)
else
build_result(columns: [], rows: [])
end
end
end
end
def exec_delete(sql, name = nil, binds = []) # :nodoc:
if without_prepared_statement?(binds)
with_raw_connection do |conn|
@affected_rows_before_warnings = nil
execute_and_free(sql, name) { @affected_rows_before_warnings || conn.affected_rows }
end
else
exec_stmt_and_free(sql, name, binds) { |stmt| stmt.affected_rows }
end
end
alias :exec_update :exec_delete
# https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_current-timestamp
# https://dev.mysql.com/doc/refman/5.7/en/date-and-time-type-syntax.html
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP(6)").freeze # :nodoc:
private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP
def high_precision_current_timestamp
HIGH_PRECISION_CURRENT_TIMESTAMP
end
def build_explain_clause(options = [])
return "EXPLAIN" if options.empty?
@ -92,52 +46,15 @@ def build_explain_clause(options = [])
end
private
def sync_timezone_changes(raw_connection)
raw_connection.query_options[:database_timezone] = default_timezone
end
def execute_batch(statements, name = nil)
statements = statements.map { |sql| transform_query(sql) }
combine_multi_statements(statements).each do |statement|
with_raw_connection do |conn|
raw_execute(statement, name)
conn.abandon_results!
end
end
# https://mariadb.com/kb/en/analyze-statement/
def analyze_without_explain?
mariadb? && database_version >= "10.1.0"
end
def default_insert_value(column)
super unless column.auto_increment?
end
def last_inserted_id(result)
@raw_connection&.last_id
end
def multi_statements_enabled?
flags = @config[:flags]
if flags.is_a?(Array)
flags.include?("MULTI_STATEMENTS")
else
flags.anybits?(Mysql2::Client::MULTI_STATEMENTS)
end
end
def with_multi_statements
if multi_statements_enabled?
return yield
end
with_raw_connection do |conn|
conn.set_server_option(Mysql2::Client::OPTION_MULTI_STATEMENTS_ON)
yield
ensure
conn.set_server_option(Mysql2::Client::OPTION_MULTI_STATEMENTS_OFF)
end
end
def combine_multi_statements(total_sql)
total_sql.each_with_object([]) do |sql, total_sql_chunks|
previous_packet = total_sql_chunks.last
@ -164,61 +81,6 @@ def max_allowed_packet_reached?(current_packet, previous_packet)
def max_allowed_packet
@max_allowed_packet ||= show_variable("max_allowed_packet")
end
def raw_execute(sql, name, async: false, allow_retry: false, materialize_transactions: true)
log(sql, name, async: async) do
with_raw_connection(allow_retry: allow_retry, materialize_transactions: materialize_transactions) do |conn|
sync_timezone_changes(conn)
result = conn.query(sql)
handle_warnings(sql)
result
end
end
end
def exec_stmt_and_free(sql, name, binds, cache_stmt: false, async: false)
sql = transform_query(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)
type_casted_binds = type_casted_binds(binds)
log(sql, name, binds, type_casted_binds, async: async) do
with_raw_connection do |conn|
sync_timezone_changes(conn)
if cache_stmt
stmt = @statements[sql] ||= conn.prepare(sql)
else
stmt = conn.prepare(sql)
end
begin
result = ActiveSupport::Dependencies.interlock.permit_concurrent_loads do
stmt.execute(*type_casted_binds)
end
rescue Mysql2::Error => e
if cache_stmt
@statements.delete(sql)
else
stmt.close
end
raise e
end
ret = yield stmt, result
result.free if result
stmt.close unless cache_stmt
ret
end
end
end
# https://mariadb.com/kb/en/analyze-statement/
def analyze_without_explain?
mariadb? && database_version >= "10.1.0"
end
end
end
end

@ -0,0 +1,148 @@
# frozen_string_literal: true
module ActiveRecord
module ConnectionAdapters
module Mysql2
module DatabaseStatements
# Returns an ActiveRecord::Result instance.
def select_all(*, **) # :nodoc:
result = nil
with_raw_connection do |conn|
result = if ExplainRegistry.collect? && prepared_statements
unprepared_statement { super }
else
super
end
conn.abandon_results!
end
result
end
def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false) # :nodoc:
if without_prepared_statement?(binds)
execute_and_free(sql, name, async: async) do |result|
if result
build_result(columns: result.fields, rows: result.to_a)
else
build_result(columns: [], rows: [])
end
end
else
exec_stmt_and_free(sql, name, binds, cache_stmt: prepare, async: async) do |_, result|
if result
build_result(columns: result.fields, rows: result.to_a)
else
build_result(columns: [], rows: [])
end
end
end
end
def exec_delete(sql, name = nil, binds = []) # :nodoc:
if without_prepared_statement?(binds)
with_raw_connection do |conn|
@affected_rows_before_warnings = nil
execute_and_free(sql, name) { @affected_rows_before_warnings || conn.affected_rows }
end
else
exec_stmt_and_free(sql, name, binds) { |stmt| stmt.affected_rows }
end
end
alias :exec_update :exec_delete
private
def sync_timezone_changes(raw_connection)
raw_connection.query_options[:database_timezone] = default_timezone
end
def execute_batch(statements, name = nil)
statements = statements.map { |sql| transform_query(sql) }
combine_multi_statements(statements).each do |statement|
with_raw_connection do |conn|
raw_execute(statement, name)
conn.abandon_results!
end
end
end
def last_inserted_id(result)
@raw_connection&.last_id
end
def multi_statements_enabled?
flags = @config[:flags]
if flags.is_a?(Array)
flags.include?("MULTI_STATEMENTS")
else
flags.anybits?(::Mysql2::Client::MULTI_STATEMENTS)
end
end
def with_multi_statements
if multi_statements_enabled?
return yield
end
with_raw_connection do |conn|
conn.set_server_option(::Mysql2::Client::OPTION_MULTI_STATEMENTS_ON)
yield
ensure
conn.set_server_option(::Mysql2::Client::OPTION_MULTI_STATEMENTS_OFF)
end
end
def raw_execute(sql, name, async: false, allow_retry: false, materialize_transactions: true)
log(sql, name, async: async) do
with_raw_connection(allow_retry: allow_retry, materialize_transactions: materialize_transactions) do |conn|
sync_timezone_changes(conn)
result = conn.query(sql)
handle_warnings(sql)
result
end
end
end
def exec_stmt_and_free(sql, name, binds, cache_stmt: false, async: false)
sql = transform_query(sql)
check_if_write_query(sql)
mark_transaction_written_if_write(sql)
type_casted_binds = type_casted_binds(binds)
log(sql, name, binds, type_casted_binds, async: async) do
with_raw_connection do |conn|
sync_timezone_changes(conn)
if cache_stmt
stmt = @statements[sql] ||= conn.prepare(sql)
else
stmt = conn.prepare(sql)
end
begin
result = ActiveSupport::Dependencies.interlock.permit_concurrent_loads do
stmt.execute(*type_casted_binds)
end
rescue Mysql2::Error => e
if cache_stmt
@statements.delete(sql)
else
stmt.close
end
raise e
end
ret = yield stmt, result
result.free if result
stmt.close unless cache_stmt
ret
end
end
end
end
end
end
end

@ -1,7 +1,7 @@
# frozen_string_literal: true
require "active_record/connection_adapters/abstract_mysql_adapter"
require "active_record/connection_adapters/mysql/database_statements"
require "active_record/connection_adapters/mysql2/database_statements"
gem "mysql2", "~> 0.5"
require "mysql2"
@ -28,12 +28,12 @@ class Mysql2Adapter < AbstractMysqlAdapter
ADAPTER_NAME = "Mysql2"
include MySQL::DatabaseStatements
include Mysql2::DatabaseStatements
class << self
def new_client(config)
Mysql2::Client.new(config)
rescue Mysql2::Error => error
::Mysql2::Client.new(config)
rescue ::Mysql2::Error => error
if error.error_number == ConnectionAdapters::Mysql2Adapter::ER_BAD_DB_ERROR
raise ActiveRecord::NoDatabaseError.db_error(config[:database])
elsif error.error_number == ConnectionAdapters::Mysql2Adapter::ER_ACCESS_DENIED_ERROR
@ -54,7 +54,7 @@ def initialize(...)
if @config[:flags].kind_of? Array
@config[:flags].push "FOUND_ROWS"
else
@config[:flags] |= Mysql2::Client::FOUND_ROWS
@config[:flags] |= ::Mysql2::Client::FOUND_ROWS
end
@connection_parameters ||= @config
@ -159,9 +159,9 @@ def get_full_version
end
def translate_exception(exception, message:, sql:, binds:)
if exception.is_a?(Mysql2::Error::TimeoutError) && !exception.error_number
if exception.is_a?(::Mysql2::Error::TimeoutError) && !exception.error_number
ActiveRecord::AdapterTimeout.new(message, sql: sql, binds: binds)
elsif exception.is_a?(Mysql2::Error::ConnectionError)
elsif exception.is_a?(::Mysql2::Error::ConnectionError)
if exception.message.match?(/MySQL client is not connected/i)
ActiveRecord::ConnectionNotEstablished.new(exception)
else

@ -4,14 +4,6 @@ module ActiveRecord
module ConnectionAdapters
module Trilogy
module DatabaseStatements
READ_QUERY = AbstractAdapter.build_read_query_regexp(
:desc, :describe, :set, :show, :use
) # :nodoc:
private_constant :READ_QUERY
HIGH_PRECISION_CURRENT_TIMESTAMP = Arel.sql("CURRENT_TIMESTAMP(6)").freeze # :nodoc:
private_constant :HIGH_PRECISION_CURRENT_TIMESTAMP
def select_all(*, **) # :nodoc:
result = nil
with_raw_connection do |conn|
@ -21,21 +13,6 @@ def select_all(*, **) # :nodoc:
result
end
def write_query?(sql) # :nodoc:
!READ_QUERY.match?(sql)
rescue ArgumentError # Invalid encoding
!READ_QUERY.match?(sql.b)
end
def explain(arel, binds = [], options = [])
sql = build_explain_clause(options) + " " + to_sql(arel, binds)
start = Process.clock_gettime(Process::CLOCK_MONOTONIC)
result = internal_exec_query(sql, "EXPLAIN", binds)
elapsed = Process.clock_gettime(Process::CLOCK_MONOTONIC) - start
MySQL::ExplainPrettyPrinter.new.pp(result, elapsed)
end
def internal_exec_query(sql, name = "SQL", binds = [], prepare: false, async: false) # :nodoc:
sql = transform_query(sql)
check_if_write_query(sql)
@ -64,22 +41,6 @@ def exec_delete(sql, name = nil, binds = []) # :nodoc:
alias :exec_update :exec_delete # :nodoc:
def high_precision_current_timestamp
HIGH_PRECISION_CURRENT_TIMESTAMP
end
def build_explain_clause(options = [])
return "EXPLAIN" if options.empty?
explain_clause = "EXPLAIN #{options.join(" ").upcase}"
if analyze_without_explain? && explain_clause.include?("ANALYZE")
explain_clause.sub("EXPLAIN ", "")
else
explain_clause
end
end
private
def raw_execute(sql, name, async: false, allow_retry: false, materialize_transactions: true)
log(sql, name, async: async) do
@ -95,6 +56,43 @@ def raw_execute(sql, name, async: false, allow_retry: false, materialize_transac
def last_inserted_id(result)
result.last_insert_id
end
def sync_timezone_changes(conn)
# Sync any changes since connection last established.
if default_timezone == :local
conn.query_flags |= ::Trilogy::QUERY_FLAGS_LOCAL_TIMEZONE
else
conn.query_flags &= ~::Trilogy::QUERY_FLAGS_LOCAL_TIMEZONE
end
end
def execute_batch(statements, name = nil)
statements = statements.map { |sql| transform_query(sql) }
combine_multi_statements(statements).each do |statement|
with_raw_connection do |conn|
raw_execute(statement, name)
conn.next_result while conn.more_results_exist?
end
end
end
def multi_statements_enabled?
!!@config[:multi_statement]
end
def with_multi_statements
if multi_statements_enabled?
return yield
end
with_raw_connection do |conn|
conn.set_server_option(::Trilogy::SET_SERVER_MULTI_STATEMENTS_ON)
yield
ensure
conn.set_server_option(::Trilogy::SET_SERVER_MULTI_STATEMENTS_OFF)
end
end
end
end
end

@ -182,70 +182,6 @@ def reconnect
connect
end
def sync_timezone_changes(conn)
# Sync any changes since connection last established.
if default_timezone == :local
conn.query_flags |= ::Trilogy::QUERY_FLAGS_LOCAL_TIMEZONE
else
conn.query_flags &= ~::Trilogy::QUERY_FLAGS_LOCAL_TIMEZONE
end
end
def execute_batch(statements, name = nil)
statements = statements.map { |sql| transform_query(sql) }
combine_multi_statements(statements).each do |statement|
with_raw_connection do |conn|
raw_execute(statement, name)
conn.next_result while conn.more_results_exist?
end
end
end
def multi_statements_enabled?
!!@config[:multi_statement]
end
def with_multi_statements
if multi_statements_enabled?
return yield
end
with_raw_connection do |conn|
conn.set_server_option(::Trilogy::SET_SERVER_MULTI_STATEMENTS_ON)
yield
ensure
conn.set_server_option(::Trilogy::SET_SERVER_MULTI_STATEMENTS_OFF)
end
end
def combine_multi_statements(total_sql)
total_sql.each_with_object([]) do |sql, total_sql_chunks|
previous_packet = total_sql_chunks.last
if max_allowed_packet_reached?(sql, previous_packet)
total_sql_chunks << +sql
else
previous_packet << ";\n"
previous_packet << sql
end
end
end
def max_allowed_packet_reached?(current_packet, previous_packet)
if current_packet.bytesize > max_allowed_packet
raise ActiveRecordError,
"Fixtures set is too large #{current_packet.bytesize}. Consider increasing the max_allowed_packet variable."
elsif previous_packet.nil?
true
else
(current_packet.bytesize + previous_packet.bytesize + 2) > max_allowed_packet
end
end
def max_allowed_packet
@max_allowed_packet ||= show_variable("max_allowed_packet")
end
def full_version
schema_cache.database_version.full_version_string
end
@ -265,15 +201,6 @@ def translate_exception(exception, message:, sql:, binds:)
def default_prepared_statements
false
end
def default_insert_value(column)
super unless column.auto_increment?
end
# https://mariadb.com/kb/en/analyze-statement/
def analyze_without_explain?
mariadb? && database_version >= "10.1.0"
end
end
end
end

@ -52,7 +52,7 @@ def render_bind(attr)
end
def build_explain_clause(options = [])
if connection.respond_to?(:build_explain_clause)
if connection.respond_to?(:build_explain_clause, true)
connection.build_explain_clause(options)
else
"EXPLAIN for:"