refactor: Use mutex from utils (#1851)

Fixes #1359.
This commit is contained in:
Sergey Kuznetsov
2025-01-27 15:28:00 +00:00
committed by GitHub
parent 6ef6ca9e65
commit 540e938223
11 changed files with 112 additions and 107 deletions

View File

@@ -23,6 +23,7 @@
#include "data/cassandra/Handle.hpp" #include "data/cassandra/Handle.hpp"
#include "data/cassandra/Types.hpp" #include "data/cassandra/Types.hpp"
#include "data/cassandra/impl/RetryPolicy.hpp" #include "data/cassandra/impl/RetryPolicy.hpp"
#include "util/Mutex.hpp"
#include "util/log/Logger.hpp" #include "util/log/Logger.hpp"
#include <boost/asio.hpp> #include <boost/asio.hpp>
@@ -64,8 +65,8 @@ class AsyncExecutor : public std::enable_shared_from_this<AsyncExecutor<Statemen
RetryCallbackType onRetry_; RetryCallbackType onRetry_;
// does not exist during initial construction, hence optional // does not exist during initial construction, hence optional
std::optional<FutureWithCallbackType> future_; using OptionalFuture = std::optional<FutureWithCallbackType>;
std::mutex mtx_; util::Mutex<OptionalFuture> future_;
public: public:
/** /**
@@ -127,8 +128,8 @@ private:
self = nullptr; // explicitly decrement refcount self = nullptr; // explicitly decrement refcount
}; };
std::scoped_lock const lck{mtx_}; auto future = future_.template lock<std::scoped_lock>();
future_.emplace(handle.asyncExecute(data_, std::move(handler))); future->emplace(handle.asyncExecute(data_, std::move(handler)));
} }
}; };

View File

@@ -19,6 +19,8 @@
#pragma once #pragma once
#include "util/Mutex.hpp"
#include <boost/signals2.hpp> #include <boost/signals2.hpp>
#include <boost/signals2/connection.hpp> #include <boost/signals2/connection.hpp>
#include <boost/signals2/variadic_signal.hpp> #include <boost/signals2/variadic_signal.hpp>
@@ -45,8 +47,8 @@ class TrackableSignal {
// map of connection and signal connection, key is the pointer of the connection object // map of connection and signal connection, key is the pointer of the connection object
// allow disconnect to be called in the destructor of the connection // allow disconnect to be called in the destructor of the connection
std::unordered_map<ConnectionPtr, boost::signals2::connection> connections_; using ConnectionsMap = std::unordered_map<ConnectionPtr, boost::signals2::connection>;
mutable std::mutex mutex_; util::Mutex<ConnectionsMap> connections_;
using SignalType = boost::signals2::signal<void(Args...)>; using SignalType = boost::signals2::signal<void(Args...)>;
SignalType signal_; SignalType signal_;
@@ -64,8 +66,8 @@ public:
bool bool
connectTrackableSlot(ConnectionSharedPtr const& trackable, std::function<void(Args...)> slot) connectTrackableSlot(ConnectionSharedPtr const& trackable, std::function<void(Args...)> slot)
{ {
std::scoped_lock const lk(mutex_); auto connections = connections_.template lock<std::scoped_lock>();
if (connections_.contains(trackable.get())) { if (connections->contains(trackable.get())) {
return false; return false;
} }
@@ -73,7 +75,7 @@ public:
// the trackable's destructor. However, the trackable can not be destroied when the slot is being called // the trackable's destructor. However, the trackable can not be destroied when the slot is being called
// either. track_foreign will hold a weak_ptr to the connection, which makes sure the connection is valid when // either. track_foreign will hold a weak_ptr to the connection, which makes sure the connection is valid when
// the slot is called. // the slot is called.
connections_.emplace( connections->emplace(
trackable.get(), signal_.connect(typename SignalType::slot_type(slot).track_foreign(trackable)) trackable.get(), signal_.connect(typename SignalType::slot_type(slot).track_foreign(trackable))
); );
return true; return true;
@@ -89,10 +91,9 @@ public:
bool bool
disconnect(ConnectionPtr trackablePtr) disconnect(ConnectionPtr trackablePtr)
{ {
std::scoped_lock const lk(mutex_); if (auto connections = connections_.template lock<std::scoped_lock>(); connections->contains(trackablePtr)) {
if (connections_.contains(trackablePtr)) { connections->operator[](trackablePtr).disconnect();
connections_[trackablePtr].disconnect(); connections->erase(trackablePtr);
connections_.erase(trackablePtr);
return true; return true;
} }
return false; return false;
@@ -115,8 +116,7 @@ public:
std::size_t std::size_t
count() const count() const
{ {
std::scoped_lock const lk(mutex_); return connections_.template lock<std::scoped_lock>()->size();
return connections_.size();
} }
}; };
} // namespace feed::impl } // namespace feed::impl

View File

@@ -20,6 +20,7 @@
#pragma once #pragma once
#include "feed/impl/TrackableSignal.hpp" #include "feed/impl/TrackableSignal.hpp"
#include "util/Mutex.hpp"
#include <boost/signals2.hpp> #include <boost/signals2.hpp>
@@ -49,8 +50,8 @@ class TrackableSignalMap {
using ConnectionPtr = Session*; using ConnectionPtr = Session*;
using ConnectionSharedPtr = std::shared_ptr<Session>; using ConnectionSharedPtr = std::shared_ptr<Session>;
mutable std::mutex mutex_; using SignalsMap = std::unordered_map<Key, TrackableSignal<Session, Args...>>;
std::unordered_map<Key, TrackableSignal<Session, Args...>> signalsMap_; util::Mutex<SignalsMap> signalsMap_;
public: public:
/** /**
@@ -66,8 +67,8 @@ public:
bool bool
connectTrackableSlot(ConnectionSharedPtr const& trackable, Key const& key, std::function<void(Args...)> slot) connectTrackableSlot(ConnectionSharedPtr const& trackable, Key const& key, std::function<void(Args...)> slot)
{ {
std::scoped_lock const lk(mutex_); auto map = signalsMap_.template lock<std::scoped_lock>();
return signalsMap_[key].connectTrackableSlot(trackable, slot); return map->operator[](key).connectTrackableSlot(trackable, slot);
} }
/** /**
@@ -80,14 +81,14 @@ public:
bool bool
disconnect(ConnectionPtr trackablePtr, Key const& key) disconnect(ConnectionPtr trackablePtr, Key const& key)
{ {
std::scoped_lock const lk(mutex_); auto map = signalsMap_.template lock<std::scoped_lock>();
if (!signalsMap_.contains(key)) if (!map->contains(key))
return false; return false;
auto const disconnected = signalsMap_[key].disconnect(trackablePtr); auto const disconnected = map->operator[](key).disconnect(trackablePtr);
// clean the map if there is no connection left. // clean the map if there is no connection left.
if (disconnected && signalsMap_[key].count() == 0) if (disconnected && map->operator[](key).count() == 0)
signalsMap_.erase(key); map->erase(key);
return disconnected; return disconnected;
} }
@@ -101,9 +102,9 @@ public:
void void
emit(Key const& key, Args const&... args) emit(Key const& key, Args const&... args)
{ {
std::scoped_lock const lk(mutex_); auto map = signalsMap_.template lock<std::scoped_lock>();
if (signalsMap_.contains(key)) if (map->contains(key))
signalsMap_[key].emit(args...); map->operator[](key).emit(args...);
} }
}; };
} // namespace feed::impl } // namespace feed::impl

View File

@@ -34,7 +34,7 @@ class Mutex;
* @tparam LockType type of lock * @tparam LockType type of lock
* @tparam MutexType type of mutex * @tparam MutexType type of mutex
*/ */
template <typename ProtectedDataType, template <typename> typename LockType, typename MutexType> template <typename ProtectedDataType, template <typename...> typename LockType, typename MutexType>
class Lock { class Lock {
LockType<MutexType> lock_; LockType<MutexType> lock_;
ProtectedDataType& data_; ProtectedDataType& data_;
@@ -129,7 +129,7 @@ public:
* @tparam LockType The type of lock to use * @tparam LockType The type of lock to use
* @return A lock on the mutex and a reference to the protected data * @return A lock on the mutex and a reference to the protected data
*/ */
template <template <typename> typename LockType = std::lock_guard> template <template <typename...> typename LockType = std::lock_guard>
Lock<ProtectedDataType const, LockType, MutexType> Lock<ProtectedDataType const, LockType, MutexType>
lock() const lock() const
{ {
@@ -142,7 +142,7 @@ public:
* @tparam LockType The type of lock to use * @tparam LockType The type of lock to use
* @return A lock on the mutex and a reference to the protected data * @return A lock on the mutex and a reference to the protected data
*/ */
template <template <typename> typename LockType = std::lock_guard> template <template <typename...> typename LockType = std::lock_guard>
Lock<ProtectedDataType, LockType, MutexType> Lock<ProtectedDataType, LockType, MutexType>
lock() lock()
{ {

View File

@@ -21,6 +21,7 @@
#include "util/Assert.hpp" #include "util/Assert.hpp"
#include "util/Concepts.hpp" #include "util/Concepts.hpp"
#include "util/Mutex.hpp"
#include "util/prometheus/OStream.hpp" #include "util/prometheus/OStream.hpp"
#include <cstdint> #include <cstdint>
@@ -61,28 +62,30 @@ public:
void void
setBuckets(std::vector<ValueType> const& bounds) setBuckets(std::vector<ValueType> const& bounds)
{ {
std::scoped_lock const lock{*mutex_}; auto data = data_->template lock<std::scoped_lock>();
ASSERT(buckets_.empty(), "Buckets can be set only once."); ASSERT(data->buckets.empty(), "Buckets can be set only once.");
buckets_.reserve(bounds.size()); data->buckets.reserve(bounds.size());
for (auto const& bound : bounds) { for (auto const& bound : bounds) {
buckets_.emplace_back(bound); data->buckets.emplace_back(bound);
} }
} }
void void
observe(ValueType const value) observe(ValueType const value)
{ {
auto const bucket = auto data = data_->template lock<std::scoped_lock>();
std::lower_bound(buckets_.begin(), buckets_.end(), value, [](Bucket const& bucket, ValueType const& value) { auto const bucket = std::lower_bound(
return bucket.upperBound < value; data->buckets.begin(),
}); data->buckets.end(),
std::scoped_lock const lock{*mutex_}; value,
if (bucket != buckets_.end()) { [](Bucket const& bucket, ValueType const& value) { return bucket.upperBound < value; }
);
if (bucket != data->buckets.end()) {
++bucket->count; ++bucket->count;
} else { } else {
++lastBucket_.count; ++data->lastBucket.count;
} }
sum_ += value; data->sum += value;
} }
void void
@@ -98,15 +101,15 @@ public:
labelsString.back() = ','; labelsString.back() = ',';
} }
std::scoped_lock const lock{*mutex_}; auto data = data_->template lock<std::scoped_lock>();
std::uint64_t cumulativeCount = 0; std::uint64_t cumulativeCount = 0;
for (auto const& bucket : buckets_) { for (auto const& bucket : data->buckets) {
cumulativeCount += bucket.count; cumulativeCount += bucket.count;
stream << name << "_bucket" << labelsString << "le=\"" << bucket.upperBound << "\"} " << cumulativeCount stream << name << "_bucket" << labelsString << "le=\"" << bucket.upperBound << "\"} " << cumulativeCount
<< '\n'; << '\n';
} }
cumulativeCount += lastBucket_.count; cumulativeCount += data->lastBucket.count;
stream << name << "_bucket" << labelsString << "le=\"+Inf\"} " << cumulativeCount << '\n'; stream << name << "_bucket" << labelsString << "le=\"+Inf\"} " << cumulativeCount << '\n';
if (labelsString.size() == 1) { if (labelsString.size() == 1) {
@@ -114,7 +117,7 @@ public:
} else { } else {
labelsString.back() = '}'; labelsString.back() = '}';
} }
stream << name << "_sum" << labelsString << " " << sum_ << '\n'; stream << name << "_sum" << labelsString << " " << data->sum << '\n';
stream << name << "_count" << labelsString << " " << cumulativeCount << '\n'; stream << name << "_count" << labelsString << " " << cumulativeCount << '\n';
} }
@@ -128,10 +131,12 @@ private:
std::uint64_t count = 0; std::uint64_t count = 0;
}; };
std::vector<Bucket> buckets_; struct Data {
Bucket lastBucket_{std::numeric_limits<ValueType>::max()}; std::vector<Bucket> buckets;
ValueType sum_ = 0; Bucket lastBucket{std::numeric_limits<ValueType>::max()};
mutable std::unique_ptr<std::mutex> mutex_ = std::make_unique<std::mutex>(); ValueType sum = 0;
};
std::unique_ptr<util::Mutex<Data>> data_ = std::make_unique<util::Mutex<Data>>();
}; };
} // namespace util::prometheus::impl } // namespace util::prometheus::impl

View File

@@ -58,17 +58,17 @@ DOSGuard::isOk(std::string const& ip) const noexcept
return true; return true;
{ {
std::scoped_lock const lck(mtx_); auto lock = mtx_.lock<std::scoped_lock>();
if (ipState_.find(ip) != ipState_.end()) { if (lock->ipState.find(ip) != lock->ipState.end()) {
auto [transferedByte, requests] = ipState_.at(ip); auto [transferredByte, requests] = lock->ipState.at(ip);
if (transferedByte > maxFetches_ || requests > maxRequestCount_) { if (transferredByte > maxFetches_ || requests > maxRequestCount_) {
LOG(log_.warn()) << "Dosguard: Client surpassed the rate limit. ip = " << ip LOG(log_.warn()) << "Dosguard: Client surpassed the rate limit. ip = " << ip
<< " Transfered Byte: " << transferedByte << "; Requests: " << requests; << " Transfered Byte: " << transferredByte << "; Requests: " << requests;
return false; return false;
} }
} }
auto it = ipConnCount_.find(ip); auto it = lock->ipConnCount.find(ip);
if (it != ipConnCount_.end()) { if (it != lock->ipConnCount.end()) {
if (it->second > maxConnCount_) { if (it->second > maxConnCount_) {
LOG(log_.warn()) << "Dosguard: Client surpassed the rate limit. ip = " << ip LOG(log_.warn()) << "Dosguard: Client surpassed the rate limit. ip = " << ip
<< " Concurrent connection: " << it->second; << " Concurrent connection: " << it->second;
@@ -84,8 +84,8 @@ DOSGuard::increment(std::string const& ip) noexcept
{ {
if (whitelistHandler_.get().isWhiteListed(ip)) if (whitelistHandler_.get().isWhiteListed(ip))
return; return;
std::scoped_lock const lck{mtx_}; auto lock = mtx_.lock<std::scoped_lock>();
ipConnCount_[ip]++; lock->ipConnCount[ip]++;
} }
void void
@@ -93,11 +93,11 @@ DOSGuard::decrement(std::string const& ip) noexcept
{ {
if (whitelistHandler_.get().isWhiteListed(ip)) if (whitelistHandler_.get().isWhiteListed(ip))
return; return;
std::scoped_lock const lck{mtx_}; auto lock = mtx_.lock<std::scoped_lock>();
ASSERT(ipConnCount_[ip] > 0, "Connection count for ip {} can't be 0", ip); ASSERT(lock->ipConnCount[ip] > 0, "Connection count for ip {} can't be 0", ip);
ipConnCount_[ip]--; lock->ipConnCount[ip]--;
if (ipConnCount_[ip] == 0) if (lock->ipConnCount[ip] == 0)
ipConnCount_.erase(ip); lock->ipConnCount.erase(ip);
} }
[[maybe_unused]] bool [[maybe_unused]] bool
@@ -107,8 +107,8 @@ DOSGuard::add(std::string const& ip, uint32_t numObjects) noexcept
return true; return true;
{ {
std::scoped_lock const lck(mtx_); auto lock = mtx_.lock<std::scoped_lock>();
ipState_[ip].transferedByte += numObjects; lock->ipState[ip].transferedByte += numObjects;
} }
return isOk(ip); return isOk(ip);
@@ -121,8 +121,8 @@ DOSGuard::request(std::string const& ip) noexcept
return true; return true;
{ {
std::scoped_lock const lck(mtx_); auto lock = mtx_.lock<std::scoped_lock>();
ipState_[ip].requestsCount++; lock->ipState[ip].requestsCount++;
} }
return isOk(ip); return isOk(ip);
@@ -131,8 +131,8 @@ DOSGuard::request(std::string const& ip) noexcept
void void
DOSGuard::clear() noexcept DOSGuard::clear() noexcept
{ {
std::scoped_lock const lck(mtx_); auto lock = mtx_.lock<std::scoped_lock>();
ipState_.clear(); lock->ipState.clear();
} }
[[nodiscard]] std::unordered_set<std::string> [[nodiscard]] std::unordered_set<std::string>

View File

@@ -19,6 +19,7 @@
#pragma once #pragma once
#include "util/Mutex.hpp"
#include "util/log/Logger.hpp" #include "util/log/Logger.hpp"
#include "util/newconfig/ConfigDefinition.hpp" #include "util/newconfig/ConfigDefinition.hpp"
#include "web/dosguard/DOSGuardInterface.hpp" #include "web/dosguard/DOSGuardInterface.hpp"
@@ -30,7 +31,6 @@
#include <cstdint> #include <cstdint>
#include <functional> #include <functional>
#include <mutex>
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <unordered_map> #include <unordered_map>
@@ -52,9 +52,12 @@ class DOSGuard : public DOSGuardInterface {
std::uint32_t requestsCount = 0; /**< Accumulated served requests count */ std::uint32_t requestsCount = 0; /**< Accumulated served requests count */
}; };
mutable std::mutex mtx_; struct State {
std::unordered_map<std::string, ClientState> ipState_; std::unordered_map<std::string, ClientState> ipState;
std::unordered_map<std::string, std::uint32_t> ipConnCount_; std::unordered_map<std::string, std::uint32_t> ipConnCount;
};
util::Mutex<State> mtx_;
std::reference_wrapper<WhitelistHandlerInterface const> whitelistHandler_; std::reference_wrapper<WhitelistHandlerInterface const> whitelistHandler_;
std::uint32_t const maxFetches_; std::uint32_t const maxFetches_;

View File

@@ -20,6 +20,7 @@
#pragma once #pragma once
#include "data/Types.hpp" #include "data/Types.hpp"
#include "util/Mutex.hpp"
#include <xrpl/basics/Blob.h> #include <xrpl/basics/Blob.h>
#include <xrpl/basics/base_uint.h> #include <xrpl/basics/base_uint.h>
@@ -27,7 +28,6 @@
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <map> #include <map>
#include <mutex>
#include <optional> #include <optional>
#include <thread> #include <thread>
#include <vector> #include <vector>
@@ -57,14 +57,14 @@ struct DiffProvider {
nextKey(std::size_t keysSize) nextKey(std::size_t keysSize)
{ {
// mock the result from doFetchSuccessorKey, be aware this function will be called from multiple threads // mock the result from doFetchSuccessorKey, be aware this function will be called from multiple threads
std::lock_guard<std::mutex> const guard(keysMutex_); auto keysMap = threadKeysMap_.lock();
threadKeysMap_[std::this_thread::get_id()]++; keysMap->operator[](std::this_thread::get_id())++;
if (threadKeysMap_[std::this_thread::get_id()] == keysSize - 1) { if (keysMap->operator[](std::this_thread::get_id()) == keysSize - 1) {
return data::kLAST_KEY; return data::kLAST_KEY;
} }
if (threadKeysMap_[std::this_thread::get_id()] == keysSize) { if (keysMap->operator[](std::this_thread::get_id()) == keysSize) {
threadKeysMap_[std::this_thread::get_id()] = 0; keysMap->operator[](std::this_thread::get_id()) = 0;
return std::nullopt; return std::nullopt;
} }
@@ -72,6 +72,5 @@ struct DiffProvider {
} }
private: private:
std::mutex keysMutex_; util::Mutex<std::map<std::thread::id, uint32_t>> threadKeysMap_;
std::map<std::thread::id, uint32_t> threadKeysMap_;
}; };

View File

@@ -22,6 +22,7 @@
#include "data/DBHelpers.hpp" #include "data/DBHelpers.hpp"
#include "migration/cassandra/impl/TransactionsAdapter.hpp" #include "migration/cassandra/impl/TransactionsAdapter.hpp"
#include "migration/cassandra/impl/Types.hpp" #include "migration/cassandra/impl/Types.hpp"
#include "util/Mutex.hpp"
#include "util/newconfig/ObjectView.hpp" #include "util/newconfig/ObjectView.hpp"
#include <xrpl/basics/base_uint.h> #include <xrpl/basics/base_uint.h>
@@ -31,7 +32,6 @@
#include <cstdint> #include <cstdint>
#include <memory> #include <memory>
#include <mutex>
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
@@ -47,23 +47,20 @@ ExampleTransactionsMigrator::runMigration(
auto const jobsFullScan = config.get<std::uint32_t>("full_scan_jobs"); auto const jobsFullScan = config.get<std::uint32_t>("full_scan_jobs");
auto const cursorPerJobsFullScan = config.get<std::uint32_t>("cursors_per_job"); auto const cursorPerJobsFullScan = config.get<std::uint32_t>("cursors_per_job");
std::unordered_set<std::string> hashSet; using HashSet = std::unordered_set<std::string>;
std::mutex mtx; // protect hashSet util::Mutex<HashSet> hashSet;
migration::cassandra::impl::TransactionsScanner scaner( migration::cassandra::impl::TransactionsScanner scanner(
{.ctxThreadsNum = ctxFullScanThreads, .jobsNum = jobsFullScan, .cursorsPerJob = cursorPerJobsFullScan}, {.ctxThreadsNum = ctxFullScanThreads, .jobsNum = jobsFullScan, .cursorsPerJob = cursorPerJobsFullScan},
migration::cassandra::impl::TransactionsAdapter( migration::cassandra::impl::TransactionsAdapter(
backend, backend,
[&](ripple::STTx const& tx, ripple::TxMeta const&) { [&](ripple::STTx const& tx, ripple::TxMeta const&) {
{ hashSet.lock()->insert(ripple::to_string(tx.getTransactionID()));
std::lock_guard<std::mutex> const lock(mtx);
hashSet.insert(ripple::to_string(tx.getTransactionID()));
}
auto const json = tx.getJson(ripple::JsonOptions::none); auto const json = tx.getJson(ripple::JsonOptions::none);
auto const txType = json["TransactionType"].asString(); auto const txType = json["TransactionType"].asString();
backend->writeTxIndexExample(uint256ToString(tx.getTransactionID()), txType); backend->writeTxIndexExample(uint256ToString(tx.getTransactionID()), txType);
} }
) )
); );
scaner.wait(); scanner.wait();
count = hashSet.size(); count = hashSet.lock()->size();
} }

View File

@@ -26,6 +26,7 @@
#include "util/Assert.hpp" #include "util/Assert.hpp"
#include "util/LoggerFixtures.hpp" #include "util/LoggerFixtures.hpp"
#include "util/MockXrpLedgerAPIService.hpp" #include "util/MockXrpLedgerAPIService.hpp"
#include "util/Mutex.hpp"
#include "util/TestObject.hpp" #include "util/TestObject.hpp"
#include <gmock/gmock.h> #include <gmock/gmock.h>
@@ -69,9 +70,9 @@ struct GrpcSourceNgTests : NoLoggerFixture, tests::util::WithMockXrpLedgerAPISer
class KeyStore { class KeyStore {
std::vector<ripple::uint256> keys_; std::vector<ripple::uint256> keys_;
std::map<std::string, std::queue<ripple::uint256>, std::greater<>> store_; using Store = std::map<std::string, std::queue<ripple::uint256>, std::greater<>>;
std::mutex mtx_; util::Mutex<Store> store_;
public: public:
KeyStore(std::size_t totalKeys, std::size_t numMarkers) : keys_(etl::getMarkers(totalKeys)) KeyStore(std::size_t totalKeys, std::size_t numMarkers) : keys_(etl::getMarkers(totalKeys))
@@ -79,10 +80,11 @@ struct GrpcSourceNgTests : NoLoggerFixture, tests::util::WithMockXrpLedgerAPISer
auto const totalPerMarker = totalKeys / numMarkers; auto const totalPerMarker = totalKeys / numMarkers;
auto const markers = etl::getMarkers(numMarkers); auto const markers = etl::getMarkers(numMarkers);
auto store = store_.lock();
for (auto mi = 0uz; mi < markers.size(); ++mi) { for (auto mi = 0uz; mi < markers.size(); ++mi) {
for (auto i = 0uz; i < totalPerMarker; ++i) { for (auto i = 0uz; i < totalPerMarker; ++i) {
auto const mapKey = ripple::strHex(markers.at(mi)).substr(0, 2); auto const mapKey = ripple::strHex(markers.at(mi)).substr(0, 2);
store_[mapKey].push(keys_.at((mi * totalPerMarker) + i)); store->operator[](mapKey).push(keys_.at((mi * totalPerMarker) + i));
} }
} }
} }
@@ -90,11 +92,11 @@ struct GrpcSourceNgTests : NoLoggerFixture, tests::util::WithMockXrpLedgerAPISer
std::optional<std::string> std::optional<std::string>
next(std::string const& marker) next(std::string const& marker)
{ {
std::scoped_lock const lock(mtx_); auto store = store_.lock<std::scoped_lock>();
auto const mapKey = ripple::strHex(marker).substr(0, 2); auto const mapKey = ripple::strHex(marker).substr(0, 2);
auto it = store_.lower_bound(mapKey); auto it = store->lower_bound(mapKey);
ASSERT(it != store_.end(), "Lower bound not found for '{}'", mapKey); ASSERT(it != store->end(), "Lower bound not found for '{}'", mapKey);
auto& queue = it->second; auto& queue = it->second;
if (queue.empty()) if (queue.empty())
@@ -109,11 +111,11 @@ struct GrpcSourceNgTests : NoLoggerFixture, tests::util::WithMockXrpLedgerAPISer
std::optional<std::string> std::optional<std::string>
peek(std::string const& marker) peek(std::string const& marker)
{ {
std::scoped_lock const lock(mtx_); auto store = store_.lock<std::scoped_lock>();
auto const mapKey = ripple::strHex(marker).substr(0, 2); auto const mapKey = ripple::strHex(marker).substr(0, 2);
auto it = store_.lower_bound(mapKey); auto it = store->lower_bound(mapKey);
ASSERT(it != store_.end(), "Lower bound not found for '{}'", mapKey); ASSERT(it != store->end(), "Lower bound not found for '{}'", mapKey);
auto& queue = it->second; auto& queue = it->second;
if (queue.empty()) if (queue.empty())

View File

@@ -29,8 +29,8 @@
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <atomic>
#include <condition_variable> #include <condition_variable>
#include <cstdint>
#include <mutex> #include <mutex>
#include <semaphore> #include <semaphore>
@@ -53,14 +53,11 @@ struct WorkQueueTest : WithPrometheus, RPCWorkQueueTestBase {};
TEST_F(WorkQueueTest, WhitelistedExecutionCountAddsUp) TEST_F(WorkQueueTest, WhitelistedExecutionCountAddsUp)
{ {
static constexpr auto kTOTAL = 512u; static constexpr auto kTOTAL = 512u;
uint32_t executeCount = 0u; std::atomic_uint32_t executeCount = 0u;
std::mutex mtx;
for (auto i = 0u; i < kTOTAL; ++i) { for (auto i = 0u; i < kTOTAL; ++i) {
queue.postCoro( queue.postCoro(
[&executeCount, &mtx](auto /* yield */) { [&executeCount](auto /* yield */) {
std::lock_guard const lk(mtx);
++executeCount; ++executeCount;
}, },
true true