refactor: Use mutex from utils (#1851)

Fixes #1359.
This commit is contained in:
Sergey Kuznetsov
2025-01-27 15:28:00 +00:00
committed by GitHub
parent 6ef6ca9e65
commit 540e938223
11 changed files with 112 additions and 107 deletions

View File

@@ -23,6 +23,7 @@
#include "data/cassandra/Handle.hpp"
#include "data/cassandra/Types.hpp"
#include "data/cassandra/impl/RetryPolicy.hpp"
#include "util/Mutex.hpp"
#include "util/log/Logger.hpp"
#include <boost/asio.hpp>
@@ -64,8 +65,8 @@ class AsyncExecutor : public std::enable_shared_from_this<AsyncExecutor<Statemen
RetryCallbackType onRetry_;
// does not exist during initial construction, hence optional
std::optional<FutureWithCallbackType> future_;
std::mutex mtx_;
using OptionalFuture = std::optional<FutureWithCallbackType>;
util::Mutex<OptionalFuture> future_;
public:
/**
@@ -127,8 +128,8 @@ private:
self = nullptr; // explicitly decrement refcount
};
std::scoped_lock const lck{mtx_};
future_.emplace(handle.asyncExecute(data_, std::move(handler)));
auto future = future_.template lock<std::scoped_lock>();
future->emplace(handle.asyncExecute(data_, std::move(handler)));
}
};

View File

@@ -19,6 +19,8 @@
#pragma once
#include "util/Mutex.hpp"
#include <boost/signals2.hpp>
#include <boost/signals2/connection.hpp>
#include <boost/signals2/variadic_signal.hpp>
@@ -45,8 +47,8 @@ class TrackableSignal {
// map of connection and signal connection, key is the pointer of the connection object
// allow disconnect to be called in the destructor of the connection
std::unordered_map<ConnectionPtr, boost::signals2::connection> connections_;
mutable std::mutex mutex_;
using ConnectionsMap = std::unordered_map<ConnectionPtr, boost::signals2::connection>;
util::Mutex<ConnectionsMap> connections_;
using SignalType = boost::signals2::signal<void(Args...)>;
SignalType signal_;
@@ -64,8 +66,8 @@ public:
bool
connectTrackableSlot(ConnectionSharedPtr const& trackable, std::function<void(Args...)> slot)
{
std::scoped_lock const lk(mutex_);
if (connections_.contains(trackable.get())) {
auto connections = connections_.template lock<std::scoped_lock>();
if (connections->contains(trackable.get())) {
return false;
}
@@ -73,7 +75,7 @@ public:
// the trackable's destructor. However, the trackable can not be destroied when the slot is being called
// either. track_foreign will hold a weak_ptr to the connection, which makes sure the connection is valid when
// the slot is called.
connections_.emplace(
connections->emplace(
trackable.get(), signal_.connect(typename SignalType::slot_type(slot).track_foreign(trackable))
);
return true;
@@ -89,10 +91,9 @@ public:
bool
disconnect(ConnectionPtr trackablePtr)
{
std::scoped_lock const lk(mutex_);
if (connections_.contains(trackablePtr)) {
connections_[trackablePtr].disconnect();
connections_.erase(trackablePtr);
if (auto connections = connections_.template lock<std::scoped_lock>(); connections->contains(trackablePtr)) {
connections->operator[](trackablePtr).disconnect();
connections->erase(trackablePtr);
return true;
}
return false;
@@ -115,8 +116,7 @@ public:
std::size_t
count() const
{
std::scoped_lock const lk(mutex_);
return connections_.size();
return connections_.template lock<std::scoped_lock>()->size();
}
};
} // namespace feed::impl

View File

@@ -20,6 +20,7 @@
#pragma once
#include "feed/impl/TrackableSignal.hpp"
#include "util/Mutex.hpp"
#include <boost/signals2.hpp>
@@ -49,8 +50,8 @@ class TrackableSignalMap {
using ConnectionPtr = Session*;
using ConnectionSharedPtr = std::shared_ptr<Session>;
mutable std::mutex mutex_;
std::unordered_map<Key, TrackableSignal<Session, Args...>> signalsMap_;
using SignalsMap = std::unordered_map<Key, TrackableSignal<Session, Args...>>;
util::Mutex<SignalsMap> signalsMap_;
public:
/**
@@ -66,8 +67,8 @@ public:
bool
connectTrackableSlot(ConnectionSharedPtr const& trackable, Key const& key, std::function<void(Args...)> slot)
{
std::scoped_lock const lk(mutex_);
return signalsMap_[key].connectTrackableSlot(trackable, slot);
auto map = signalsMap_.template lock<std::scoped_lock>();
return map->operator[](key).connectTrackableSlot(trackable, slot);
}
/**
@@ -80,14 +81,14 @@ public:
bool
disconnect(ConnectionPtr trackablePtr, Key const& key)
{
std::scoped_lock const lk(mutex_);
if (!signalsMap_.contains(key))
auto map = signalsMap_.template lock<std::scoped_lock>();
if (!map->contains(key))
return false;
auto const disconnected = signalsMap_[key].disconnect(trackablePtr);
auto const disconnected = map->operator[](key).disconnect(trackablePtr);
// clean the map if there is no connection left.
if (disconnected && signalsMap_[key].count() == 0)
signalsMap_.erase(key);
if (disconnected && map->operator[](key).count() == 0)
map->erase(key);
return disconnected;
}
@@ -101,9 +102,9 @@ public:
void
emit(Key const& key, Args const&... args)
{
std::scoped_lock const lk(mutex_);
if (signalsMap_.contains(key))
signalsMap_[key].emit(args...);
auto map = signalsMap_.template lock<std::scoped_lock>();
if (map->contains(key))
map->operator[](key).emit(args...);
}
};
} // namespace feed::impl

View File

@@ -34,7 +34,7 @@ class Mutex;
* @tparam LockType type of lock
* @tparam MutexType type of mutex
*/
template <typename ProtectedDataType, template <typename> typename LockType, typename MutexType>
template <typename ProtectedDataType, template <typename...> typename LockType, typename MutexType>
class Lock {
LockType<MutexType> lock_;
ProtectedDataType& data_;
@@ -129,7 +129,7 @@ public:
* @tparam LockType The type of lock to use
* @return A lock on the mutex and a reference to the protected data
*/
template <template <typename> typename LockType = std::lock_guard>
template <template <typename...> typename LockType = std::lock_guard>
Lock<ProtectedDataType const, LockType, MutexType>
lock() const
{
@@ -142,7 +142,7 @@ public:
* @tparam LockType The type of lock to use
* @return A lock on the mutex and a reference to the protected data
*/
template <template <typename> typename LockType = std::lock_guard>
template <template <typename...> typename LockType = std::lock_guard>
Lock<ProtectedDataType, LockType, MutexType>
lock()
{

View File

@@ -21,6 +21,7 @@
#include "util/Assert.hpp"
#include "util/Concepts.hpp"
#include "util/Mutex.hpp"
#include "util/prometheus/OStream.hpp"
#include <cstdint>
@@ -61,28 +62,30 @@ public:
void
setBuckets(std::vector<ValueType> const& bounds)
{
std::scoped_lock const lock{*mutex_};
ASSERT(buckets_.empty(), "Buckets can be set only once.");
buckets_.reserve(bounds.size());
auto data = data_->template lock<std::scoped_lock>();
ASSERT(data->buckets.empty(), "Buckets can be set only once.");
data->buckets.reserve(bounds.size());
for (auto const& bound : bounds) {
buckets_.emplace_back(bound);
data->buckets.emplace_back(bound);
}
}
void
observe(ValueType const value)
{
auto const bucket =
std::lower_bound(buckets_.begin(), buckets_.end(), value, [](Bucket const& bucket, ValueType const& value) {
return bucket.upperBound < value;
});
std::scoped_lock const lock{*mutex_};
if (bucket != buckets_.end()) {
auto data = data_->template lock<std::scoped_lock>();
auto const bucket = std::lower_bound(
data->buckets.begin(),
data->buckets.end(),
value,
[](Bucket const& bucket, ValueType const& value) { return bucket.upperBound < value; }
);
if (bucket != data->buckets.end()) {
++bucket->count;
} else {
++lastBucket_.count;
++data->lastBucket.count;
}
sum_ += value;
data->sum += value;
}
void
@@ -98,15 +101,15 @@ public:
labelsString.back() = ',';
}
std::scoped_lock const lock{*mutex_};
auto data = data_->template lock<std::scoped_lock>();
std::uint64_t cumulativeCount = 0;
for (auto const& bucket : buckets_) {
for (auto const& bucket : data->buckets) {
cumulativeCount += bucket.count;
stream << name << "_bucket" << labelsString << "le=\"" << bucket.upperBound << "\"} " << cumulativeCount
<< '\n';
}
cumulativeCount += lastBucket_.count;
cumulativeCount += data->lastBucket.count;
stream << name << "_bucket" << labelsString << "le=\"+Inf\"} " << cumulativeCount << '\n';
if (labelsString.size() == 1) {
@@ -114,7 +117,7 @@ public:
} else {
labelsString.back() = '}';
}
stream << name << "_sum" << labelsString << " " << sum_ << '\n';
stream << name << "_sum" << labelsString << " " << data->sum << '\n';
stream << name << "_count" << labelsString << " " << cumulativeCount << '\n';
}
@@ -128,10 +131,12 @@ private:
std::uint64_t count = 0;
};
std::vector<Bucket> buckets_;
Bucket lastBucket_{std::numeric_limits<ValueType>::max()};
ValueType sum_ = 0;
mutable std::unique_ptr<std::mutex> mutex_ = std::make_unique<std::mutex>();
struct Data {
std::vector<Bucket> buckets;
Bucket lastBucket{std::numeric_limits<ValueType>::max()};
ValueType sum = 0;
};
std::unique_ptr<util::Mutex<Data>> data_ = std::make_unique<util::Mutex<Data>>();
};
} // namespace util::prometheus::impl

View File

@@ -58,17 +58,17 @@ DOSGuard::isOk(std::string const& ip) const noexcept
return true;
{
std::scoped_lock const lck(mtx_);
if (ipState_.find(ip) != ipState_.end()) {
auto [transferedByte, requests] = ipState_.at(ip);
if (transferedByte > maxFetches_ || requests > maxRequestCount_) {
auto lock = mtx_.lock<std::scoped_lock>();
if (lock->ipState.find(ip) != lock->ipState.end()) {
auto [transferredByte, requests] = lock->ipState.at(ip);
if (transferredByte > maxFetches_ || requests > maxRequestCount_) {
LOG(log_.warn()) << "Dosguard: Client surpassed the rate limit. ip = " << ip
<< " Transfered Byte: " << transferedByte << "; Requests: " << requests;
<< " Transfered Byte: " << transferredByte << "; Requests: " << requests;
return false;
}
}
auto it = ipConnCount_.find(ip);
if (it != ipConnCount_.end()) {
auto it = lock->ipConnCount.find(ip);
if (it != lock->ipConnCount.end()) {
if (it->second > maxConnCount_) {
LOG(log_.warn()) << "Dosguard: Client surpassed the rate limit. ip = " << ip
<< " Concurrent connection: " << it->second;
@@ -84,8 +84,8 @@ DOSGuard::increment(std::string const& ip) noexcept
{
if (whitelistHandler_.get().isWhiteListed(ip))
return;
std::scoped_lock const lck{mtx_};
ipConnCount_[ip]++;
auto lock = mtx_.lock<std::scoped_lock>();
lock->ipConnCount[ip]++;
}
void
@@ -93,11 +93,11 @@ DOSGuard::decrement(std::string const& ip) noexcept
{
if (whitelistHandler_.get().isWhiteListed(ip))
return;
std::scoped_lock const lck{mtx_};
ASSERT(ipConnCount_[ip] > 0, "Connection count for ip {} can't be 0", ip);
ipConnCount_[ip]--;
if (ipConnCount_[ip] == 0)
ipConnCount_.erase(ip);
auto lock = mtx_.lock<std::scoped_lock>();
ASSERT(lock->ipConnCount[ip] > 0, "Connection count for ip {} can't be 0", ip);
lock->ipConnCount[ip]--;
if (lock->ipConnCount[ip] == 0)
lock->ipConnCount.erase(ip);
}
[[maybe_unused]] bool
@@ -107,8 +107,8 @@ DOSGuard::add(std::string const& ip, uint32_t numObjects) noexcept
return true;
{
std::scoped_lock const lck(mtx_);
ipState_[ip].transferedByte += numObjects;
auto lock = mtx_.lock<std::scoped_lock>();
lock->ipState[ip].transferedByte += numObjects;
}
return isOk(ip);
@@ -121,8 +121,8 @@ DOSGuard::request(std::string const& ip) noexcept
return true;
{
std::scoped_lock const lck(mtx_);
ipState_[ip].requestsCount++;
auto lock = mtx_.lock<std::scoped_lock>();
lock->ipState[ip].requestsCount++;
}
return isOk(ip);
@@ -131,8 +131,8 @@ DOSGuard::request(std::string const& ip) noexcept
void
DOSGuard::clear() noexcept
{
std::scoped_lock const lck(mtx_);
ipState_.clear();
auto lock = mtx_.lock<std::scoped_lock>();
lock->ipState.clear();
}
[[nodiscard]] std::unordered_set<std::string>

View File

@@ -19,6 +19,7 @@
#pragma once
#include "util/Mutex.hpp"
#include "util/log/Logger.hpp"
#include "util/newconfig/ConfigDefinition.hpp"
#include "web/dosguard/DOSGuardInterface.hpp"
@@ -30,7 +31,6 @@
#include <cstdint>
#include <functional>
#include <mutex>
#include <string>
#include <string_view>
#include <unordered_map>
@@ -52,9 +52,12 @@ class DOSGuard : public DOSGuardInterface {
std::uint32_t requestsCount = 0; /**< Accumulated served requests count */
};
mutable std::mutex mtx_;
std::unordered_map<std::string, ClientState> ipState_;
std::unordered_map<std::string, std::uint32_t> ipConnCount_;
struct State {
std::unordered_map<std::string, ClientState> ipState;
std::unordered_map<std::string, std::uint32_t> ipConnCount;
};
util::Mutex<State> mtx_;
std::reference_wrapper<WhitelistHandlerInterface const> whitelistHandler_;
std::uint32_t const maxFetches_;

View File

@@ -20,6 +20,7 @@
#pragma once
#include "data/Types.hpp"
#include "util/Mutex.hpp"
#include <xrpl/basics/Blob.h>
#include <xrpl/basics/base_uint.h>
@@ -27,7 +28,6 @@
#include <cstddef>
#include <cstdint>
#include <map>
#include <mutex>
#include <optional>
#include <thread>
#include <vector>
@@ -57,14 +57,14 @@ struct DiffProvider {
nextKey(std::size_t keysSize)
{
// mock the result from doFetchSuccessorKey, be aware this function will be called from multiple threads
std::lock_guard<std::mutex> const guard(keysMutex_);
threadKeysMap_[std::this_thread::get_id()]++;
auto keysMap = threadKeysMap_.lock();
keysMap->operator[](std::this_thread::get_id())++;
if (threadKeysMap_[std::this_thread::get_id()] == keysSize - 1) {
if (keysMap->operator[](std::this_thread::get_id()) == keysSize - 1) {
return data::kLAST_KEY;
}
if (threadKeysMap_[std::this_thread::get_id()] == keysSize) {
threadKeysMap_[std::this_thread::get_id()] = 0;
if (keysMap->operator[](std::this_thread::get_id()) == keysSize) {
keysMap->operator[](std::this_thread::get_id()) = 0;
return std::nullopt;
}
@@ -72,6 +72,5 @@ struct DiffProvider {
}
private:
std::mutex keysMutex_;
std::map<std::thread::id, uint32_t> threadKeysMap_;
util::Mutex<std::map<std::thread::id, uint32_t>> threadKeysMap_;
};

View File

@@ -22,6 +22,7 @@
#include "data/DBHelpers.hpp"
#include "migration/cassandra/impl/TransactionsAdapter.hpp"
#include "migration/cassandra/impl/Types.hpp"
#include "util/Mutex.hpp"
#include "util/newconfig/ObjectView.hpp"
#include <xrpl/basics/base_uint.h>
@@ -31,7 +32,6 @@
#include <cstdint>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_set>
@@ -47,23 +47,20 @@ ExampleTransactionsMigrator::runMigration(
auto const jobsFullScan = config.get<std::uint32_t>("full_scan_jobs");
auto const cursorPerJobsFullScan = config.get<std::uint32_t>("cursors_per_job");
std::unordered_set<std::string> hashSet;
std::mutex mtx; // protect hashSet
migration::cassandra::impl::TransactionsScanner scaner(
using HashSet = std::unordered_set<std::string>;
util::Mutex<HashSet> hashSet;
migration::cassandra::impl::TransactionsScanner scanner(
{.ctxThreadsNum = ctxFullScanThreads, .jobsNum = jobsFullScan, .cursorsPerJob = cursorPerJobsFullScan},
migration::cassandra::impl::TransactionsAdapter(
backend,
[&](ripple::STTx const& tx, ripple::TxMeta const&) {
{
std::lock_guard<std::mutex> const lock(mtx);
hashSet.insert(ripple::to_string(tx.getTransactionID()));
}
hashSet.lock()->insert(ripple::to_string(tx.getTransactionID()));
auto const json = tx.getJson(ripple::JsonOptions::none);
auto const txType = json["TransactionType"].asString();
backend->writeTxIndexExample(uint256ToString(tx.getTransactionID()), txType);
}
)
);
scaner.wait();
count = hashSet.size();
scanner.wait();
count = hashSet.lock()->size();
}

View File

@@ -26,6 +26,7 @@
#include "util/Assert.hpp"
#include "util/LoggerFixtures.hpp"
#include "util/MockXrpLedgerAPIService.hpp"
#include "util/Mutex.hpp"
#include "util/TestObject.hpp"
#include <gmock/gmock.h>
@@ -69,9 +70,9 @@ struct GrpcSourceNgTests : NoLoggerFixture, tests::util::WithMockXrpLedgerAPISer
class KeyStore {
std::vector<ripple::uint256> keys_;
std::map<std::string, std::queue<ripple::uint256>, std::greater<>> store_;
using Store = std::map<std::string, std::queue<ripple::uint256>, std::greater<>>;
std::mutex mtx_;
util::Mutex<Store> store_;
public:
KeyStore(std::size_t totalKeys, std::size_t numMarkers) : keys_(etl::getMarkers(totalKeys))
@@ -79,10 +80,11 @@ struct GrpcSourceNgTests : NoLoggerFixture, tests::util::WithMockXrpLedgerAPISer
auto const totalPerMarker = totalKeys / numMarkers;
auto const markers = etl::getMarkers(numMarkers);
auto store = store_.lock();
for (auto mi = 0uz; mi < markers.size(); ++mi) {
for (auto i = 0uz; i < totalPerMarker; ++i) {
auto const mapKey = ripple::strHex(markers.at(mi)).substr(0, 2);
store_[mapKey].push(keys_.at((mi * totalPerMarker) + i));
store->operator[](mapKey).push(keys_.at((mi * totalPerMarker) + i));
}
}
}
@@ -90,11 +92,11 @@ struct GrpcSourceNgTests : NoLoggerFixture, tests::util::WithMockXrpLedgerAPISer
std::optional<std::string>
next(std::string const& marker)
{
std::scoped_lock const lock(mtx_);
auto store = store_.lock<std::scoped_lock>();
auto const mapKey = ripple::strHex(marker).substr(0, 2);
auto it = store_.lower_bound(mapKey);
ASSERT(it != store_.end(), "Lower bound not found for '{}'", mapKey);
auto it = store->lower_bound(mapKey);
ASSERT(it != store->end(), "Lower bound not found for '{}'", mapKey);
auto& queue = it->second;
if (queue.empty())
@@ -109,11 +111,11 @@ struct GrpcSourceNgTests : NoLoggerFixture, tests::util::WithMockXrpLedgerAPISer
std::optional<std::string>
peek(std::string const& marker)
{
std::scoped_lock const lock(mtx_);
auto store = store_.lock<std::scoped_lock>();
auto const mapKey = ripple::strHex(marker).substr(0, 2);
auto it = store_.lower_bound(mapKey);
ASSERT(it != store_.end(), "Lower bound not found for '{}'", mapKey);
auto it = store->lower_bound(mapKey);
ASSERT(it != store->end(), "Lower bound not found for '{}'", mapKey);
auto& queue = it->second;
if (queue.empty())

View File

@@ -29,8 +29,8 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <atomic>
#include <condition_variable>
#include <cstdint>
#include <mutex>
#include <semaphore>
@@ -53,14 +53,11 @@ struct WorkQueueTest : WithPrometheus, RPCWorkQueueTestBase {};
TEST_F(WorkQueueTest, WhitelistedExecutionCountAddsUp)
{
static constexpr auto kTOTAL = 512u;
uint32_t executeCount = 0u;
std::mutex mtx;
std::atomic_uint32_t executeCount = 0u;
for (auto i = 0u; i < kTOTAL; ++i) {
queue.postCoro(
[&executeCount, &mtx](auto /* yield */) {
std::lock_guard const lk(mtx);
[&executeCount](auto /* yield */) {
++executeCount;
},
true