make mysql stuff thread safe, compiling untested

This commit is contained in:
Richard Holland
2025-02-14 13:55:46 +11:00
parent c98d9b58de
commit 584af6bda0
2 changed files with 330 additions and 178 deletions

View File

@@ -13,12 +13,86 @@
namespace ripple {
struct MySQLDeleter
{
void
operator()(MYSQL* mysql)
{
if (mysql)
{
mysql_close(mysql);
}
}
};
// Thread-local MySQL connection
static thread_local std::unique_ptr<MYSQL, MySQLDeleter> threadLocalMySQL_;
class MySQLDatabase : public SQLiteDatabase
{
private:
Application& app_;
bool const useTxTables_;
std::unique_ptr<MYSQL, decltype(&mysql_close)> mysql_;
// Configuration for creating new connections
struct MySQLConfig
{
std::string host;
std::string user;
std::string pass;
std::string name;
unsigned int port;
};
MySQLConfig config_;
// Initialize a new MySQL connection using stored config
MYSQL*
initializeConnection()
{
MYSQL* mysql = mysql_init(nullptr);
if (!mysql)
{
throw std::runtime_error("Failed to initialize MySQL");
}
if (!mysql_real_connect(
mysql,
config_.host.c_str(),
config_.user.c_str(),
config_.pass.c_str(),
nullptr, // Don't select database in connection
config_.port,
nullptr,
0))
{
auto error = mysql_error(mysql);
mysql_close(mysql);
throw std::runtime_error(
std::string("Failed to connect to MySQL: ") + error);
}
// Select the database
if (mysql_select_db(mysql, config_.name.c_str()))
{
auto error = mysql_error(mysql);
mysql_close(mysql);
throw std::runtime_error(
std::string("Failed to select database: ") + error);
}
return mysql;
}
// Get the thread-local MySQL connection, creating it if necessary
MYSQL*
getConnection()
{
if (!threadLocalMySQL_)
{
threadLocalMySQL_.reset(initializeConnection());
}
return threadLocalMySQL_.get();
}
// Schema creation statements
static constexpr auto CREATE_LEDGERS_TABLE = R"SQL(
@@ -76,73 +150,52 @@ public:
// Then modify the constructor:
MySQLDatabase(Application& app, Config const& config, JobQueue& jobQueue)
: app_(app)
, useTxTables_(config.useTxTables())
, mysql_(mysql_init(nullptr), mysql_close)
: app_(app), useTxTables_(config.useTxTables())
{
if (!mysql_)
throw std::runtime_error("Failed to initialize MySQL");
if (!config.mysql.has_value())
throw std::runtime_error(
"[mysql_settings] stanza missing from config!");
// Read MySQL connection details from config
auto* conn = mysql_real_connect(
mysql_.get(),
config.mysql->host.c_str(),
config.mysql->user.c_str(),
config.mysql->pass.c_str(),
nullptr, // Don't select database in connection
config.mysql->port,
nullptr,
0);
// Store configuration for creating new connections
config_.host = config.mysql->host;
config_.user = config.mysql->user;
config_.pass = config.mysql->pass;
config_.name = config.mysql->name;
config_.port = config.mysql->port;
if (!conn)
throw std::runtime_error(
std::string("Failed to connect to MySQL: ") +
mysql_error(mysql_.get()));
// Initialize first connection and create schema
auto mysql = getConnection();
// Create database if it doesn't exist
std::string create_db =
"CREATE DATABASE IF NOT EXISTS " + config.mysql->name;
std::cout << "create_db: `" << create_db << "`\n";
if (mysql_query(mysql_.get(), create_db.c_str()))
std::string create_db = "CREATE DATABASE IF NOT EXISTS " + config_.name;
if (mysql_query(mysql, create_db.c_str()))
throw std::runtime_error(
std::string("Failed to create database (2): ") +
mysql_error(mysql_.get()));
std::string("Failed to create database: ") +
mysql_error(mysql));
// Select the database
if (mysql_select_db(mysql_.get(), config.mysql->name.c_str()))
throw std::runtime_error(
std::string("Failed to select database: ") +
mysql_error(mysql_.get()));
// Create nodes table first
if (mysql_query(mysql_.get(), CREATE_NODES_TABLE))
// Create schema tables
if (mysql_query(mysql, CREATE_NODES_TABLE))
throw std::runtime_error(
std::string("Failed to create nodes table: ") +
mysql_error(mysql_.get()));
mysql_error(mysql));
// Create schema if not exists
if (mysql_query(mysql_.get(), CREATE_LEDGERS_TABLE))
if (mysql_query(mysql, CREATE_LEDGERS_TABLE))
throw std::runtime_error(
std::string("Failed to create ledgers table: ") +
mysql_error(mysql_.get()));
mysql_error(mysql));
if (useTxTables_)
{
if (mysql_query(mysql_.get(), CREATE_TRANSACTIONS_TABLE))
if (mysql_query(mysql, CREATE_TRANSACTIONS_TABLE))
throw std::runtime_error(
std::string("Failed to create transactions table: ") +
mysql_error(mysql_.get()));
mysql_error(mysql));
if (mysql_query(mysql_.get(), CREATE_ACCOUNT_TRANSACTIONS_TABLE))
if (mysql_query(mysql, CREATE_ACCOUNT_TRANSACTIONS_TABLE))
throw std::runtime_error(
std::string(
"Failed to create account_transactions table: ") +
mysql_error(mysql_.get()));
mysql_error(mysql));
}
}
@@ -186,10 +239,10 @@ public:
<< "account_hash = VALUES(account_hash), "
<< "tx_hash = VALUES(tx_hash)";
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
{
JLOG(j.fatal())
<< "Failed to save ledger: " << mysql_error(mysql_.get());
<< "Failed to save ledger: " << mysql_error(getConnection());
return false;
}
@@ -214,10 +267,10 @@ public:
}
// Start a transaction for saving all transactions
if (mysql_query(mysql_.get(), "START TRANSACTION"))
if (mysql_query(getConnection(), "START TRANSACTION"))
{
JLOG(j.fatal()) << "Failed to start transaction: "
<< mysql_error(mysql_.get());
<< mysql_error(getConnection());
return false;
}
@@ -243,7 +296,7 @@ public:
<< "raw_tx = VALUES(raw_tx), "
<< "meta_data = VALUES(meta_data)";
MYSQL_STMT* stmt = mysql_stmt_init(mysql_.get());
MYSQL_STMT* stmt = mysql_stmt_init(getConnection());
if (!stmt)
{
throw std::runtime_error(
@@ -301,9 +354,11 @@ public:
<< "ON DUPLICATE KEY UPDATE "
<< "tx_hash = VALUES(tx_hash)";
if (mysql_query(mysql_.get(), accTxSql.str().c_str()))
if (mysql_query(
getConnection(), accTxSql.str().c_str()))
{
throw std::runtime_error(mysql_error(mysql_.get()));
throw std::runtime_error(
mysql_error(getConnection()));
}
}
@@ -314,15 +369,15 @@ public:
app_.config().NETWORK_ID);
}
if (mysql_query(mysql_.get(), "COMMIT"))
if (mysql_query(getConnection(), "COMMIT"))
{
throw std::runtime_error(mysql_error(mysql_.get()));
throw std::runtime_error(mysql_error(getConnection()));
}
}
catch (std::exception const& e)
{
JLOG(j.fatal()) << "Error saving transactions: " << e.what();
mysql_query(mysql_.get(), "ROLLBACK");
mysql_query(getConnection(), "ROLLBACK");
return false;
}
}
@@ -333,10 +388,10 @@ public:
std::optional<LedgerIndex>
getMinLedgerSeq() override
{
if (mysql_query(mysql_.get(), "SELECT MIN(ledger_seq) FROM ledgers"))
if (mysql_query(getConnection(), "SELECT MIN(ledger_seq) FROM ledgers"))
return std::nullopt;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return std::nullopt;
@@ -359,10 +414,10 @@ public:
return {};
if (mysql_query(
mysql_.get(), "SELECT MIN(ledger_seq) FROM transactions"))
getConnection(), "SELECT MIN(ledger_seq) FROM transactions"))
return std::nullopt;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return std::nullopt;
@@ -381,10 +436,10 @@ public:
std::optional<LedgerIndex>
getMaxLedgerSeq() override
{
if (mysql_query(mysql_.get(), "SELECT MAX(ledger_seq) FROM ledgers"))
if (mysql_query(getConnection(), "SELECT MAX(ledger_seq) FROM ledgers"))
return std::nullopt;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return std::nullopt;
@@ -409,11 +464,11 @@ public:
std::stringstream sql;
sql << "DELETE FROM account_transactions WHERE ledger_seq = "
<< ledgerSeq;
mysql_query(mysql_.get(), sql.str().c_str());
mysql_query(getConnection(), sql.str().c_str());
sql.str("");
sql << "DELETE FROM transactions WHERE ledger_seq = " << ledgerSeq;
mysql_query(mysql_.get(), sql.str().c_str());
mysql_query(getConnection(), sql.str().c_str());
}
void
@@ -424,16 +479,16 @@ public:
std::stringstream sql;
sql << "DELETE FROM account_transactions WHERE ledger_seq < "
<< ledgerSeq;
mysql_query(mysql_.get(), sql.str().c_str());
mysql_query(getConnection(), sql.str().c_str());
sql.str("");
sql << "DELETE FROM transactions WHERE ledger_seq < " << ledgerSeq;
mysql_query(mysql_.get(), sql.str().c_str());
mysql_query(getConnection(), sql.str().c_str());
}
std::stringstream sql;
sql << "DELETE FROM ledgers WHERE ledger_seq < " << ledgerSeq;
mysql_query(mysql_.get(), sql.str().c_str());
mysql_query(getConnection(), sql.str().c_str());
}
void
@@ -445,11 +500,11 @@ public:
std::stringstream sql;
sql << "DELETE FROM account_transactions WHERE ledger_seq < "
<< ledgerSeq;
mysql_query(mysql_.get(), sql.str().c_str());
mysql_query(getConnection(), sql.str().c_str());
sql.str("");
sql << "DELETE FROM transactions WHERE ledger_seq < " << ledgerSeq;
mysql_query(mysql_.get(), sql.str().c_str());
mysql_query(getConnection(), sql.str().c_str());
}
void
@@ -461,7 +516,7 @@ public:
std::stringstream sql;
sql << "DELETE FROM account_transactions WHERE ledger_seq < "
<< ledgerSeq;
mysql_query(mysql_.get(), sql.str().c_str());
mysql_query(getConnection(), sql.str().c_str());
}
std::size_t
@@ -470,10 +525,10 @@ public:
if (!useTxTables_)
return 0;
if (mysql_query(mysql_.get(), "SELECT COUNT(*) FROM transactions"))
if (mysql_query(getConnection(), "SELECT COUNT(*) FROM transactions"))
return 0;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return 0;
@@ -496,10 +551,10 @@ public:
return 0;
if (mysql_query(
mysql_.get(), "SELECT COUNT(*) FROM account_transactions"))
getConnection(), "SELECT COUNT(*) FROM account_transactions"))
return 0;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return 0;
@@ -519,12 +574,12 @@ public:
getLedgerCountMinMax() override
{
if (mysql_query(
mysql_.get(),
getConnection(),
"SELECT COUNT(*), MIN(ledger_seq), MAX(ledger_seq) FROM "
"ledgers"))
return {0, 0, 0};
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return {0, 0, 0};
@@ -552,10 +607,10 @@ public:
<< "account_hash, tx_hash FROM ledgers WHERE ledger_seq = "
<< ledgerSeq;
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return std::nullopt;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return std::nullopt;
@@ -591,10 +646,10 @@ public:
sql << "SELECT ledger_seq FROM ledgers WHERE ledger_seq >= "
<< ledgerFirstIndex << " ORDER BY ledger_seq ASC LIMIT 1";
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return std::nullopt;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return std::nullopt;
@@ -617,10 +672,10 @@ public:
sql << "SELECT ledger_seq FROM ledgers WHERE ledger_seq >= "
<< ledgerFirstIndex << " ORDER BY ledger_seq DESC LIMIT 1";
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return std::nullopt;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return std::nullopt;
@@ -643,10 +698,10 @@ public:
sql << "SELECT ledger_seq FROM ledgers WHERE ledger_hash = '"
<< strHex(ledgerHash) << "'";
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return std::nullopt;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return std::nullopt;
@@ -669,10 +724,10 @@ public:
sql << "SELECT ledger_hash FROM ledgers WHERE ledger_seq = "
<< ledgerIndex;
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return uint256();
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return uint256();
@@ -696,10 +751,10 @@ public:
"= "
<< ledgerIndex;
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return std::nullopt;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return std::nullopt;
@@ -724,10 +779,10 @@ public:
<< "WHERE ledger_seq BETWEEN " << minSeq << " AND " << maxSeq
<< " ORDER BY ledger_seq";
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return result;
MYSQL_RES* sqlResult = mysql_store_result(mysql_.get());
MYSQL_RES* sqlResult = mysql_store_result(getConnection());
if (!sqlResult)
return result;
@@ -751,11 +806,11 @@ public:
return {};
if (mysql_query(
mysql_.get(),
getConnection(),
"SELECT MIN(ledger_seq) FROM account_transactions"))
return std::nullopt;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return std::nullopt;
@@ -775,12 +830,12 @@ public:
getNewestLedgerInfo() override
{
if (mysql_query(
mysql_.get(),
getConnection(),
"SELECT ledger_seq FROM ledgers ORDER BY ledger_seq DESC LIMIT "
"1"))
return std::nullopt;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return std::nullopt;
@@ -815,10 +870,10 @@ public:
<< range->last();
}
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return TxSearched::unknown;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return TxSearched::unknown;
@@ -832,10 +887,10 @@ public:
sql << "SELECT COUNT(*) FROM ledgers WHERE ledger_seq BETWEEN "
<< range->first() << " AND " << range->last();
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return TxSearched::unknown;
result = mysql_store_result(mysql_.get());
result = mysql_store_result(getConnection());
if (!result)
return TxSearched::unknown;
@@ -921,10 +976,10 @@ public:
<< options.maxLedger << " ORDER BY at.ledger_seq, at.tx_seq"
<< " LIMIT " << (options.limit + 1);
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return {};
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return {};
@@ -991,10 +1046,10 @@ public:
<< " ORDER BY at.ledger_seq DESC, at.tx_seq DESC"
<< " LIMIT " << (options.limit + 1);
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return {};
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (!result)
return {};
@@ -1051,10 +1106,10 @@ public:
<< "ORDER BY t.ledger_seq DESC, t.tx_seq DESC "
<< "LIMIT 20 OFFSET " << startIndex;
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return result;
MYSQL_RES* sqlResult = mysql_store_result(mysql_.get());
MYSQL_RES* sqlResult = mysql_store_result(getConnection());
if (!sqlResult)
return result;
@@ -1119,10 +1174,10 @@ public:
sql << " OFFSET " << options.offset;
}
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return result;
MYSQL_RES* sqlResult = mysql_store_result(mysql_.get());
MYSQL_RES* sqlResult = mysql_store_result(getConnection());
if (!sqlResult)
return result;
@@ -1186,10 +1241,10 @@ public:
sql << " OFFSET " << options.offset;
}
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return result;
MYSQL_RES* sqlResult = mysql_store_result(mysql_.get());
MYSQL_RES* sqlResult = mysql_store_result(getConnection());
if (!sqlResult)
return result;
@@ -1253,10 +1308,10 @@ public:
sql << " OFFSET " << options.offset;
}
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return result;
MYSQL_RES* sqlResult = mysql_store_result(mysql_.get());
MYSQL_RES* sqlResult = mysql_store_result(getConnection());
if (!sqlResult)
return result;
@@ -1303,10 +1358,10 @@ public:
sql << " OFFSET " << options.offset;
}
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return result;
MYSQL_RES* sqlResult = mysql_store_result(mysql_.get());
MYSQL_RES* sqlResult = mysql_store_result(getConnection());
if (!sqlResult)
return result;
@@ -1355,10 +1410,10 @@ public:
sql << " ORDER BY at.ledger_seq ASC, at.tx_seq ASC ";
sql << "LIMIT " << (options.limit + 1);
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return {};
MYSQL_RES* sqlResult = mysql_store_result(mysql_.get());
MYSQL_RES* sqlResult = mysql_store_result(getConnection());
if (!sqlResult)
return {};
@@ -1420,10 +1475,10 @@ public:
sql << " ORDER BY at.ledger_seq DESC, at.tx_seq DESC ";
sql << "LIMIT " << (options.limit + 1);
if (mysql_query(mysql_.get(), sql.str().c_str()))
if (mysql_query(getConnection(), sql.str().c_str()))
return {};
MYSQL_RES* sqlResult = mysql_store_result(mysql_.get());
MYSQL_RES* sqlResult = mysql_store_result(getConnection());
if (!sqlResult)
return {};
@@ -1467,13 +1522,13 @@ public:
// Get ledger table size
if (!mysql_query(
mysql_.get(),
getConnection(),
"SELECT ROUND(SUM(data_length + index_length) / 1024) "
"FROM information_schema.tables "
"WHERE table_schema = DATABASE() "
"AND table_name = 'ledgers'"))
{
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (result)
{
MYSQL_ROW row = mysql_fetch_row(result);
@@ -1487,14 +1542,14 @@ public:
if (useTxTables_)
{
if (!mysql_query(
mysql_.get(),
getConnection(),
"SELECT ROUND(SUM(data_length + index_length) / 1024) "
"FROM information_schema.tables "
"WHERE table_schema = DATABASE() "
"AND (table_name = 'transactions' "
"OR table_name = 'account_transactions')"))
{
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (result)
{
MYSQL_ROW row = mysql_fetch_row(result);
@@ -1515,13 +1570,13 @@ public:
std::uint32_t total = 0;
if (!mysql_query(
mysql_.get(),
getConnection(),
"SELECT ROUND(SUM(data_length + index_length) / 1024) "
"FROM information_schema.tables "
"WHERE table_schema = DATABASE() "
"AND table_name = 'ledgers'"))
{
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (result)
{
MYSQL_ROW row = mysql_fetch_row(result);
@@ -1543,14 +1598,14 @@ public:
std::uint32_t total = 0;
if (!mysql_query(
mysql_.get(),
getConnection(),
"SELECT ROUND(SUM(data_length + index_length) / 1024) "
"FROM information_schema.tables "
"WHERE table_schema = DATABASE() "
"AND (table_name = 'transactions' "
"OR table_name = 'account_transactions')"))
{
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(getConnection());
if (result)
{
MYSQL_ROW row = mysql_fetch_row(result);

View File

@@ -12,52 +12,29 @@
#include <memory>
#include <mysql/mysql.h>
#include <sstream>
#include <thread>
namespace ripple {
namespace NodeStore {
class MySQLBackend : public Backend
class MySQLConnection
{
private:
std::string name_;
beast::Journal journal_;
bool isOpen_{false};
std::unique_ptr<MYSQL, decltype(&mysql_close)> mysql_;
Config const& config_;
static constexpr auto CREATE_DATABASE = R"SQL(
CREATE DATABASE IF NOT EXISTS `%s`
CHARACTER SET utf8mb4
COLLATE utf8mb4_unicode_ci
)SQL";
static constexpr auto CREATE_NODES_TABLE = R"SQL(
CREATE TABLE IF NOT EXISTS nodes (
hash BINARY(32) PRIMARY KEY,
data MEDIUMBLOB NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB
)SQL";
std::string const& dbName_;
beast::Journal journal_;
public:
MySQLBackend(
std::size_t keyBytes,
Section const& keyValues,
MySQLConnection(
Config const& config,
std::string const& dbName,
beast::Journal journal)
: name_(get(keyValues, "path", "nodestore"))
: mysql_(mysql_init(nullptr), mysql_close)
, config_(config)
, dbName_(dbName)
, journal_(journal)
, mysql_(mysql_init(nullptr), mysql_close)
, config_(keyValues.getParent())
{
// mysql names are limited to alphanumeric
name_.erase(
std::remove_if(
name_.begin(),
name_.end(),
[](char c) { return !std::isalnum(c); }),
name_.end());
if (!mysql_)
Throw<std::runtime_error>("Failed to initialize MySQL");
@@ -70,7 +47,7 @@ public:
config_.mysql->host.c_str(),
config_.mysql->user.c_str(),
config_.mysql->pass.c_str(),
nullptr,
dbName_.c_str(),
config_.mysql->port,
nullptr,
0);
@@ -86,19 +63,117 @@ public:
mysql_options(mysql_.get(), MYSQL_OPT_RECONNECT, &reconnect);
}
MYSQL*
get()
{
return mysql_.get();
}
bool
ensureConnection()
{
if (!mysql_ || !mysql_.get() || mysql_ping(mysql_.get()) != 0)
{
JLOG(journal_.error())
<< "MySQL connection lost, attempting reconnect";
try
{
mysql_.reset(mysql_init(nullptr));
auto* conn = mysql_real_connect(
mysql_.get(),
config_.mysql->host.c_str(),
config_.mysql->user.c_str(),
config_.mysql->pass.c_str(),
dbName_.c_str(),
config_.mysql->port,
nullptr,
0);
if (!conn)
return false;
uint8_t const reconnect = 1;
mysql_options(mysql_.get(), MYSQL_OPT_RECONNECT, &reconnect);
return true;
}
catch (...)
{
return false;
}
}
return true;
}
};
class MySQLBackend : public Backend
{
private:
std::string name_;
beast::Journal journal_;
bool isOpen_{false};
Config const& config_;
// Thread-local MySQL connection
static thread_local std::unique_ptr<MySQLConnection> threadConnection_;
static constexpr auto CREATE_DATABASE = R"SQL(
CREATE DATABASE IF NOT EXISTS `%s`
CHARACTER SET utf8mb4
COLLATE utf8mb4_unicode_ci
)SQL";
static constexpr auto CREATE_NODES_TABLE = R"SQL(
CREATE TABLE IF NOT EXISTS nodes (
hash BINARY(32) PRIMARY KEY,
data MEDIUMBLOB NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
) ENGINE=InnoDB
)SQL";
MySQLConnection*
getConnection()
{
if (!threadConnection_)
{
threadConnection_ =
std::make_unique<MySQLConnection>(config_, name_, journal_);
}
return threadConnection_.get();
}
public:
MySQLBackend(
std::size_t keyBytes,
Section const& keyValues,
beast::Journal journal)
: name_(get(keyValues, "path", "nodestore"))
, journal_(journal)
, config_(keyValues.getParent())
{
// mysql names are limited to alphanumeric
name_.erase(
std::remove_if(
name_.begin(),
name_.end(),
[](char c) { return !std::isalnum(c); }),
name_.end());
}
void
createDatabase()
{
auto conn = std::make_unique<MySQLConnection>(config_, "", journal_);
std::string query(1024, '\0');
int length =
snprintf(&query[0], query.size(), CREATE_DATABASE, name_.c_str());
query.resize(length);
if (mysql_query(mysql_.get(), query.c_str()))
if (mysql_query(conn->get(), query.c_str()))
{
Throw<std::runtime_error>(
std::string("Failed to create database: ") +
mysql_error(mysql_.get()) + " (1)");
mysql_error(conn->get()));
}
}
@@ -119,27 +194,23 @@ public:
if (isOpen_)
Throw<std::runtime_error>("already open");
// Ensure database is selected
if (!config_.mysql.has_value())
throw std::runtime_error(
"[mysql_settings] stanza missing from config!");
createDatabase();
if (mysql_select_db(mysql_.get(), name_.c_str()))
{
Throw<std::runtime_error>(
std::string("Failed to select database: ") +
mysql_error(mysql_.get()));
}
auto* conn = getConnection();
if (!conn->ensureConnection())
Throw<std::runtime_error>("Failed to establish MySQL connection");
if (createIfMissing)
{
if (mysql_query(mysql_.get(), CREATE_NODES_TABLE))
if (mysql_query(conn->get(), CREATE_NODES_TABLE))
{
Throw<std::runtime_error>(
std::string("Failed to create nodes table: ") +
mysql_error(mysql_.get()));
mysql_error(conn->get()));
}
}
@@ -155,6 +226,7 @@ public:
void
close() override
{
threadConnection_.reset();
isOpen_ = false;
}
@@ -164,9 +236,13 @@ public:
if (!isOpen_)
return notFound;
auto* conn = getConnection();
if (!conn->ensureConnection())
return dataCorrupt;
uint256 const hash(uint256::fromVoid(key));
MYSQL_STMT* stmt = mysql_stmt_init(mysql_.get());
MYSQL_STMT* stmt = mysql_stmt_init(conn->get());
if (!stmt)
return dataCorrupt;
@@ -262,7 +338,11 @@ public:
if (!isOpen_)
return {results, notFound};
if (mysql_query(mysql_.get(), "START TRANSACTION"))
auto* conn = getConnection();
if (!conn->ensureConnection())
return {results, dataCorrupt};
if (mysql_query(conn->get(), "START TRANSACTION"))
return {results, dataCorrupt};
try
@@ -274,14 +354,14 @@ public:
results.push_back(status == ok ? nObj : nullptr);
}
if (mysql_query(mysql_.get(), "COMMIT"))
if (mysql_query(conn->get(), "COMMIT"))
return {results, dataCorrupt};
return {results, ok};
}
catch (...)
{
mysql_query(mysql_.get(), "ROLLBACK");
mysql_query(conn->get(), "ROLLBACK");
throw;
}
}
@@ -292,12 +372,16 @@ public:
if (!isOpen_ || !object)
return;
auto* conn = getConnection();
if (!conn->ensureConnection())
return;
EncodedBlob encoded(object);
nudb::detail::buffer compressed;
auto const result = nodeobject_compress(
encoded.getData(), encoded.getSize(), compressed);
MYSQL_STMT* stmt = mysql_stmt_init(mysql_.get());
MYSQL_STMT* stmt = mysql_stmt_init(conn->get());
if (!stmt)
return;
@@ -307,6 +391,8 @@ public:
if (mysql_stmt_prepare(stmt, sql.c_str(), sql.length()))
{
JLOG(journal_.error()) << "Failed to prepare MySQL statement: "
<< mysql_stmt_error(stmt);
mysql_stmt_close(stmt);
return;
}
@@ -346,7 +432,11 @@ public:
if (!isOpen_)
return;
if (mysql_query(mysql_.get(), "START TRANSACTION"))
auto* conn = getConnection();
if (!conn->ensureConnection())
return;
if (mysql_query(conn->get(), "START TRANSACTION"))
return;
try
@@ -354,12 +444,12 @@ public:
for (auto const& e : batch)
store(e);
if (mysql_query(mysql_.get(), "COMMIT"))
mysql_query(mysql_.get(), "ROLLBACK");
if (mysql_query(conn->get(), "COMMIT"))
mysql_query(conn->get(), "ROLLBACK");
}
catch (...)
{
mysql_query(mysql_.get(), "ROLLBACK");
mysql_query(conn->get(), "ROLLBACK");
throw;
}
}
@@ -375,12 +465,16 @@ public:
if (!isOpen_)
return;
auto* conn = getConnection();
if (!conn->ensureConnection())
return;
if (mysql_query(
mysql_.get(),
conn->get(),
"SELECT hash, data FROM nodes ORDER BY created_at"))
return;
MYSQL_RES* result = mysql_store_result(mysql_.get());
MYSQL_RES* result = mysql_store_result(conn->get());
if (!result)
return;
@@ -424,6 +518,9 @@ public:
}
};
// Initialize the thread_local connection
thread_local std::unique_ptr<MySQLConnection> MySQLBackend::threadConnection_;
class MySQLFactory : public Factory
{
public: