fix: Use UniformRandomGenerator class to prevent threading issue (#2165)

This commit is contained in:
Ayaz Salikhov
2025-05-29 19:59:40 +01:00
committed by GitHub
parent 9b69da7f91
commit 57b8ff1c49
13 changed files with 174 additions and 48 deletions

View File

@@ -191,8 +191,9 @@ generateData()
constexpr auto kTOTAL = 10'000; constexpr auto kTOTAL = 10'000;
std::vector<uint64_t> data; std::vector<uint64_t> data;
data.reserve(kTOTAL); data.reserve(kTOTAL);
util::MTRandomGenerator randomGenerator;
for (auto i = 0; i < kTOTAL; ++i) for (auto i = 0; i < kTOTAL; ++i)
data.push_back(util::Random::uniform(1, 100'000'000)); data.push_back(randomGenerator.uniform(1, 100'000'000));
return data; return data;
} }

View File

@@ -145,9 +145,13 @@ ClioApplication::run(bool const useNgWebServer)
// The balancer itself publishes to streams (transactions_proposed and accounts_proposed) // The balancer itself publishes to streams (transactions_proposed and accounts_proposed)
auto balancer = [&] -> std::shared_ptr<etlng::LoadBalancerInterface> { auto balancer = [&] -> std::shared_ptr<etlng::LoadBalancerInterface> {
if (config_.get<bool>("__ng_etl")) if (config_.get<bool>("__ng_etl"))
return etlng::LoadBalancer::makeLoadBalancer(config_, ioc, backend, subscriptions, ledgers); return etlng::LoadBalancer::makeLoadBalancer(
config_, ioc, backend, subscriptions, std::make_unique<util::MTRandomGenerator>(), ledgers
);
return etl::LoadBalancer::makeLoadBalancer(config_, ioc, backend, subscriptions, ledgers); return etl::LoadBalancer::makeLoadBalancer(
config_, ioc, backend, subscriptions, std::make_unique<util::MTRandomGenerator>(), ledgers
);
}(); }();
// ETL is responsible for writing and publishing to streams. In read-only mode, ETL only publishes // ETL is responsible for writing and publishing to streams. In read-only mode, ETL only publishes

View File

@@ -71,12 +71,19 @@ LoadBalancer::makeLoadBalancer(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
std::shared_ptr<BackendInterface> backend, std::shared_ptr<BackendInterface> backend,
std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions, std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions,
std::unique_ptr<util::RandomGeneratorInterface> randomGenerator,
std::shared_ptr<NetworkValidatedLedgersInterface> validatedLedgers, std::shared_ptr<NetworkValidatedLedgersInterface> validatedLedgers,
SourceFactory sourceFactory SourceFactory sourceFactory
) )
{ {
return std::make_shared<LoadBalancer>( return std::make_shared<LoadBalancer>(
config, ioc, std::move(backend), std::move(subscriptions), std::move(validatedLedgers), std::move(sourceFactory) config,
ioc,
std::move(backend),
std::move(subscriptions),
std::move(randomGenerator),
std::move(validatedLedgers),
std::move(sourceFactory)
); );
} }
@@ -85,10 +92,12 @@ LoadBalancer::LoadBalancer(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
std::shared_ptr<BackendInterface> backend, std::shared_ptr<BackendInterface> backend,
std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions, std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions,
std::unique_ptr<util::RandomGeneratorInterface> randomGenerator,
std::shared_ptr<NetworkValidatedLedgersInterface> validatedLedgers, std::shared_ptr<NetworkValidatedLedgersInterface> validatedLedgers,
SourceFactory sourceFactory SourceFactory sourceFactory
) )
: forwardingCounters_{ : randomGenerator_(std::move(randomGenerator))
, forwardingCounters_{
.successDuration = PrometheusService::counterInt( .successDuration = PrometheusService::counterInt(
"forwarding_duration_milliseconds_counter", "forwarding_duration_milliseconds_counter",
Labels({util::prometheus::Label{"status", "success"}}), Labels({util::prometheus::Label{"status", "success"}}),
@@ -319,7 +328,7 @@ void
LoadBalancer::execute(Func f, uint32_t ledgerSequence, std::chrono::steady_clock::duration retryAfter) LoadBalancer::execute(Func f, uint32_t ledgerSequence, std::chrono::steady_clock::duration retryAfter)
{ {
ASSERT(not sources_.empty(), "ETL sources must be configured to execute functions."); ASSERT(not sources_.empty(), "ETL sources must be configured to execute functions.");
size_t sourceIdx = util::Random::uniform(0ul, sources_.size() - 1); size_t sourceIdx = randomGenerator_->uniform(0ul, sources_.size() - 1);
size_t numAttempts = 0; size_t numAttempts = 0;
@@ -403,7 +412,7 @@ LoadBalancer::forwardToRippledImpl(
++forwardingCounters_.cacheMiss.get(); ++forwardingCounters_.cacheMiss.get();
ASSERT(not sources_.empty(), "ETL sources must be configured to forward requests."); ASSERT(not sources_.empty(), "ETL sources must be configured to forward requests.");
std::size_t sourceIdx = util::Random::uniform(0ul, sources_.size() - 1); std::size_t sourceIdx = randomGenerator_->uniform(0ul, sources_.size() - 1);
auto numAttempts = 0u; auto numAttempts = 0u;

View File

@@ -29,6 +29,7 @@
#include "rpc/Errors.hpp" #include "rpc/Errors.hpp"
#include "util/Assert.hpp" #include "util/Assert.hpp"
#include "util/Mutex.hpp" #include "util/Mutex.hpp"
#include "util/Random.hpp"
#include "util/ResponseExpirationCache.hpp" #include "util/ResponseExpirationCache.hpp"
#include "util/config/ConfigDefinition.hpp" #include "util/config/ConfigDefinition.hpp"
#include "util/log/Logger.hpp" #include "util/log/Logger.hpp"
@@ -88,6 +89,8 @@ private:
std::optional<util::ResponseExpirationCache> forwardingCache_; std::optional<util::ResponseExpirationCache> forwardingCache_;
std::optional<std::string> forwardingXUserValue_; std::optional<std::string> forwardingXUserValue_;
std::unique_ptr<util::RandomGeneratorInterface> randomGenerator_;
std::vector<SourcePtr> sources_; std::vector<SourcePtr> sources_;
std::optional<ETLState> etlState_; std::optional<ETLState> etlState_;
std::uint32_t downloadRanges_ = std::uint32_t downloadRanges_ =
@@ -123,6 +126,7 @@ public:
* @param ioc The io_context to run on * @param ioc The io_context to run on
* @param backend BackendInterface implementation * @param backend BackendInterface implementation
* @param subscriptions Subscription manager * @param subscriptions Subscription manager
* @param randomGenerator A random generator to use for selecting sources
* @param validatedLedgers The network validated ledgers datastructure * @param validatedLedgers The network validated ledgers datastructure
* @param sourceFactory A factory function to create a source * @param sourceFactory A factory function to create a source
*/ */
@@ -131,6 +135,7 @@ public:
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
std::shared_ptr<BackendInterface> backend, std::shared_ptr<BackendInterface> backend,
std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions, std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions,
std::unique_ptr<util::RandomGeneratorInterface> randomGenerator,
std::shared_ptr<NetworkValidatedLedgersInterface> validatedLedgers, std::shared_ptr<NetworkValidatedLedgersInterface> validatedLedgers,
SourceFactory sourceFactory = makeSource SourceFactory sourceFactory = makeSource
); );
@@ -142,6 +147,7 @@ public:
* @param ioc The io_context to run on * @param ioc The io_context to run on
* @param backend BackendInterface implementation * @param backend BackendInterface implementation
* @param subscriptions Subscription manager * @param subscriptions Subscription manager
* @param randomGenerator A random generator to use for selecting sources
* @param validatedLedgers The network validated ledgers data structure * @param validatedLedgers The network validated ledgers data structure
* @param sourceFactory A factory function to create a source * @param sourceFactory A factory function to create a source
* @return A shared pointer to a new instance of LoadBalancer * @return A shared pointer to a new instance of LoadBalancer
@@ -152,6 +158,7 @@ public:
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
std::shared_ptr<BackendInterface> backend, std::shared_ptr<BackendInterface> backend,
std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions, std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions,
std::unique_ptr<util::RandomGeneratorInterface> randomGenerator,
std::shared_ptr<NetworkValidatedLedgersInterface> validatedLedgers, std::shared_ptr<NetworkValidatedLedgersInterface> validatedLedgers,
SourceFactory sourceFactory = makeSource SourceFactory sourceFactory = makeSource
); );

View File

@@ -72,12 +72,19 @@ LoadBalancer::makeLoadBalancer(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
std::shared_ptr<BackendInterface> backend, std::shared_ptr<BackendInterface> backend,
std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions, std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions,
std::unique_ptr<util::RandomGeneratorInterface> randomGenerator,
std::shared_ptr<etl::NetworkValidatedLedgersInterface> validatedLedgers, std::shared_ptr<etl::NetworkValidatedLedgersInterface> validatedLedgers,
SourceFactory sourceFactory SourceFactory sourceFactory
) )
{ {
return std::make_shared<LoadBalancer>( return std::make_shared<LoadBalancer>(
config, ioc, std::move(backend), std::move(subscriptions), std::move(validatedLedgers), std::move(sourceFactory) config,
ioc,
std::move(backend),
std::move(subscriptions),
std::move(randomGenerator),
std::move(validatedLedgers),
std::move(sourceFactory)
); );
} }
@@ -86,10 +93,12 @@ LoadBalancer::LoadBalancer(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
std::shared_ptr<BackendInterface> backend, std::shared_ptr<BackendInterface> backend,
std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions, std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions,
std::unique_ptr<util::RandomGeneratorInterface> randomGenerator,
std::shared_ptr<etl::NetworkValidatedLedgersInterface> validatedLedgers, std::shared_ptr<etl::NetworkValidatedLedgersInterface> validatedLedgers,
SourceFactory sourceFactory SourceFactory sourceFactory
) )
: forwardingCounters_{ : randomGenerator_(std::move(randomGenerator))
, forwardingCounters_{
.successDuration = PrometheusService::counterInt( .successDuration = PrometheusService::counterInt(
"forwarding_duration_milliseconds_counter", "forwarding_duration_milliseconds_counter",
Labels({util::prometheus::Label{"status", "success"}}), Labels({util::prometheus::Label{"status", "success"}}),
@@ -323,7 +332,7 @@ void
LoadBalancer::execute(Func f, uint32_t ledgerSequence, std::chrono::steady_clock::duration retryAfter) LoadBalancer::execute(Func f, uint32_t ledgerSequence, std::chrono::steady_clock::duration retryAfter)
{ {
ASSERT(not sources_.empty(), "ETL sources must be configured to execute functions."); ASSERT(not sources_.empty(), "ETL sources must be configured to execute functions.");
size_t sourceIdx = util::Random::uniform(0ul, sources_.size() - 1); size_t sourceIdx = randomGenerator_->uniform(0ul, sources_.size() - 1);
size_t numAttempts = 0; size_t numAttempts = 0;
@@ -407,7 +416,7 @@ LoadBalancer::forwardToRippledImpl(
++forwardingCounters_.cacheMiss.get(); ++forwardingCounters_.cacheMiss.get();
ASSERT(not sources_.empty(), "ETL sources must be configured to forward requests."); ASSERT(not sources_.empty(), "ETL sources must be configured to forward requests.");
std::size_t sourceIdx = util::Random::uniform(0ul, sources_.size() - 1); std::size_t sourceIdx = randomGenerator_->uniform(0ul, sources_.size() - 1);
auto numAttempts = 0u; auto numAttempts = 0u;

View File

@@ -29,6 +29,7 @@
#include "rpc/Errors.hpp" #include "rpc/Errors.hpp"
#include "util/Assert.hpp" #include "util/Assert.hpp"
#include "util/Mutex.hpp" #include "util/Mutex.hpp"
#include "util/Random.hpp"
#include "util/ResponseExpirationCache.hpp" #include "util/ResponseExpirationCache.hpp"
#include "util/config/ConfigDefinition.hpp" #include "util/config/ConfigDefinition.hpp"
#include "util/log/Logger.hpp" #include "util/log/Logger.hpp"
@@ -88,6 +89,8 @@ private:
std::optional<util::ResponseExpirationCache> forwardingCache_; std::optional<util::ResponseExpirationCache> forwardingCache_;
std::optional<std::string> forwardingXUserValue_; std::optional<std::string> forwardingXUserValue_;
std::unique_ptr<util::RandomGeneratorInterface> randomGenerator_;
std::vector<SourcePtr> sources_; std::vector<SourcePtr> sources_;
std::optional<etl::ETLState> etlState_; std::optional<etl::ETLState> etlState_;
std::uint32_t downloadRanges_ = std::uint32_t downloadRanges_ =
@@ -123,6 +126,7 @@ public:
* @param ioc The io_context to run on * @param ioc The io_context to run on
* @param backend BackendInterface implementation * @param backend BackendInterface implementation
* @param subscriptions Subscription manager * @param subscriptions Subscription manager
* @param randomGenerator A random generator to use for selecting sources
* @param validatedLedgers The network validated ledgers datastructure * @param validatedLedgers The network validated ledgers datastructure
* @param sourceFactory A factory function to create a source * @param sourceFactory A factory function to create a source
*/ */
@@ -131,6 +135,7 @@ public:
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
std::shared_ptr<BackendInterface> backend, std::shared_ptr<BackendInterface> backend,
std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions, std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions,
std::unique_ptr<util::RandomGeneratorInterface> randomGenerator,
std::shared_ptr<etl::NetworkValidatedLedgersInterface> validatedLedgers, std::shared_ptr<etl::NetworkValidatedLedgersInterface> validatedLedgers,
SourceFactory sourceFactory = makeSource SourceFactory sourceFactory = makeSource
); );
@@ -142,6 +147,7 @@ public:
* @param ioc The io_context to run on * @param ioc The io_context to run on
* @param backend BackendInterface implementation * @param backend BackendInterface implementation
* @param subscriptions Subscription manager * @param subscriptions Subscription manager
* @param randomGenerator A random generator to use for selecting sources
* @param validatedLedgers The network validated ledgers datastructure * @param validatedLedgers The network validated ledgers datastructure
* @param sourceFactory A factory function to create a source * @param sourceFactory A factory function to create a source
* @return A shared pointer to a new instance of LoadBalancer * @return A shared pointer to a new instance of LoadBalancer
@@ -152,6 +158,7 @@ public:
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
std::shared_ptr<BackendInterface> backend, std::shared_ptr<BackendInterface> backend,
std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions, std::shared_ptr<feed::SubscriptionManagerInterface> subscriptions,
std::unique_ptr<util::RandomGeneratorInterface> randomGenerator,
std::shared_ptr<etl::NetworkValidatedLedgersInterface> validatedLedgers, std::shared_ptr<etl::NetworkValidatedLedgersInterface> validatedLedgers,
SourceFactory sourceFactory = makeSource SourceFactory sourceFactory = makeSource
); );

View File

@@ -25,12 +25,20 @@
namespace util { namespace util {
void MTRandomGenerator::MTRandomGenerator() : generator_{std::chrono::system_clock::now().time_since_epoch().count()}
Random::setSeed(size_t seed)
{ {
generator.seed(seed);
} }
std::mt19937_64 Random::generator{std::chrono::system_clock::now().time_since_epoch().count()}; size_t
MTRandomGenerator::uniform(size_t min, size_t max)
{
return uniformImpl(min, max);
}
void
MTRandomGenerator::setSeed(SeedType seed)
{
generator_.seed(seed);
}
} // namespace util } // namespace util

View File

@@ -26,10 +26,52 @@
namespace util { namespace util {
/** /**
* @brief Random number generator * @brief Random number generator interface
*/ */
class Random { class RandomGeneratorInterface {
public: public:
virtual ~RandomGeneratorInterface() = default;
using SeedType = typename std::mt19937_64::result_type;
/**
* @brief Generate a random number between min and max
*
* @param min Minimum value
* @param max Maximum value
* @return Random number between min and max
*/
[[nodiscard]]
virtual size_t
uniform(size_t min, size_t max) = 0;
/**
* @brief Set the seed for the random number generator
*
* @param seed Seed to set
*/
virtual void
setSeed(SeedType seed) = 0;
};
/**
* @brief Mersenne Twister random number generator
*/
class MTRandomGenerator : public RandomGeneratorInterface {
public:
MTRandomGenerator();
/**
* @brief Generate a random number between min and max
*
* @param min Minimum value
* @param max Maximum value
* @return Random number between min and max
*/
[[nodiscard]]
size_t
uniform(size_t min, size_t max) override;
/** /**
* @brief Generate a random number between min and max * @brief Generate a random number between min and max
* *
@@ -39,16 +81,17 @@ public:
* @return Random number between min and max * @return Random number between min and max
*/ */
template <typename T> template <typename T>
static T [[nodiscard]]
uniform(T min, T max) T
uniformImpl(T min, T max)
{ {
ASSERT(min <= max, "Min cannot be greater than max. min: {}, max: {}", min, max); ASSERT(min <= max, "Min cannot be greater than max. min: {}, max: {}", min, max);
if constexpr (std::is_floating_point_v<T>) { if constexpr (std::is_floating_point_v<T>) {
std::uniform_real_distribution<T> distribution(min, max); std::uniform_real_distribution<T> distribution(min, max);
return distribution(generator); return distribution(generator_);
} }
std::uniform_int_distribution<T> distribution(min, max); std::uniform_int_distribution<T> distribution(min, max);
return distribution(generator); return distribution(generator_);
} }
/** /**
@@ -56,11 +99,11 @@ public:
* *
* @param seed Seed to set * @param seed Seed to set
*/ */
static void void
setSeed(size_t seed); setSeed(SeedType seed) override;
private: private:
static std::mt19937_64 generator; std::mt19937_64 generator_;
}; };
} // namespace util } // namespace util

View File

@@ -0,0 +1,31 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include "util/Random.hpp"
#include <gtest/gtest.h>
struct MockRandomGeneratorImpl : public util::RandomGeneratorInterface {
MOCK_METHOD(size_t, uniform, (size_t min, size_t max), (override));
MOCK_METHOD(void, setSeed, (SeedType seed), (override));
};
using MockRandomGenerator = testing::NiceMock<MockRandomGeneratorImpl>;
using StrictMockRandomGenerator = testing::StrictMock<MockRandomGeneratorImpl>;

View File

@@ -624,6 +624,7 @@ TEST_F(BackendCassandraTest, Basic)
}; };
auto generateAccountTx = [&](uint32_t ledgerSequence, auto txns) { auto generateAccountTx = [&](uint32_t ledgerSequence, auto txns) {
std::vector<AccountTransactionsData> ret; std::vector<AccountTransactionsData> ret;
util::MTRandomGenerator randomGenerator;
auto accounts = generateAccounts(ledgerSequence, 10); auto accounts = generateAccounts(ledgerSequence, 10);
uint32_t idx = 0; uint32_t idx = 0;
for (auto& [hash, txn, meta] : txns) { for (auto& [hash, txn, meta] : txns) {
@@ -632,7 +633,7 @@ TEST_F(BackendCassandraTest, Basic)
data.transactionIndex = idx; data.transactionIndex = idx;
data.txHash = hash; data.txHash = hash;
for (size_t i = 0; i < 3; ++i) { for (size_t i = 0; i < 3; ++i) {
data.accounts.insert(accounts[util::Random::uniform(0ul, accounts.size() - 1)]); data.accounts.insert(accounts[randomGenerator.uniform(0ul, accounts.size() - 1)]);
} }
++idx; ++idx;
ret.push_back(data); ret.push_back(data);

View File

@@ -24,6 +24,7 @@
#include "util/MockBackendTestFixture.hpp" #include "util/MockBackendTestFixture.hpp"
#include "util/MockNetworkValidatedLedgers.hpp" #include "util/MockNetworkValidatedLedgers.hpp"
#include "util/MockPrometheus.hpp" #include "util/MockPrometheus.hpp"
#include "util/MockRandomGenerator.hpp"
#include "util/MockSource.hpp" #include "util/MockSource.hpp"
#include "util/MockSubscriptionManager.hpp" #include "util/MockSubscriptionManager.hpp"
#include "util/NameGenerator.hpp" #include "util/NameGenerator.hpp"
@@ -123,17 +124,23 @@ struct LoadBalancerConstructorTests : util::prometheus::WithPrometheus, MockBack
makeLoadBalancer() makeLoadBalancer()
{ {
auto const cfg = getParseLoadBalancerConfig(configJson_); auto const cfg = getParseLoadBalancerConfig(configJson_);
auto randomGenerator = std::make_unique<MockRandomGenerator>();
randomGenerator_ = randomGenerator.get();
return std::make_unique<LoadBalancer>( return std::make_unique<LoadBalancer>(
cfg, cfg,
ioContext_, ioContext_,
backend_, backend_,
subscriptionManager_, subscriptionManager_,
std::move(randomGenerator),
networkManager_, networkManager_,
[this](auto&&... args) -> SourcePtr { return sourceFactory_(std::forward<decltype(args)>(args)...); } [this](auto&&... args) -> SourcePtr { return sourceFactory_(std::forward<decltype(args)>(args)...); }
); );
} }
protected: protected:
MockRandomGenerator* randomGenerator_ = nullptr;
StrictMockSubscriptionManagerSharedPtr subscriptionManager_; StrictMockSubscriptionManagerSharedPtr subscriptionManager_;
StrictMockNetworkValidatedLedgersPtr networkManager_; StrictMockNetworkValidatedLedgersPtr networkManager_;
StrictMockSourceFactory sourceFactory_{2}; StrictMockSourceFactory sourceFactory_{2};
@@ -430,11 +437,6 @@ TEST_F(LoadBalancer3SourcesTests, forwardingUpdate)
} }
struct LoadBalancerLoadInitialLedgerTests : LoadBalancerOnConnectHookTests { struct LoadBalancerLoadInitialLedgerTests : LoadBalancerOnConnectHookTests {
LoadBalancerLoadInitialLedgerTests()
{
util::Random::setSeed(0);
}
protected: protected:
uint32_t const sequence_ = 123; uint32_t const sequence_ = 123;
uint32_t const numMarkers_ = 16; uint32_t const numMarkers_ = 16;
@@ -496,7 +498,6 @@ TEST_F(LoadBalancerLoadInitialLedgerCustomNumMarkersTests, loadInitialLedger)
EXPECT_CALL(sourceFactory_.sourceAt(1), run); EXPECT_CALL(sourceFactory_.sourceAt(1), run);
auto loadBalancer = makeLoadBalancer(); auto loadBalancer = makeLoadBalancer();
util::Random::setSeed(0);
EXPECT_CALL(sourceFactory_.sourceAt(0), hasLedger(sequence_)).WillOnce(Return(true)); EXPECT_CALL(sourceFactory_.sourceAt(0), hasLedger(sequence_)).WillOnce(Return(true));
EXPECT_CALL(sourceFactory_.sourceAt(0), loadInitialLedger(sequence_, numMarkers_)).WillOnce(Return(response_)); EXPECT_CALL(sourceFactory_.sourceAt(0), loadInitialLedger(sequence_, numMarkers_)).WillOnce(Return(response_));
@@ -506,7 +507,6 @@ TEST_F(LoadBalancerLoadInitialLedgerCustomNumMarkersTests, loadInitialLedger)
struct LoadBalancerFetchLegerTests : LoadBalancerOnConnectHookTests { struct LoadBalancerFetchLegerTests : LoadBalancerOnConnectHookTests {
LoadBalancerFetchLegerTests() LoadBalancerFetchLegerTests()
{ {
util::Random::setSeed(0);
response_.second.set_validated(true); response_.second.set_validated(true);
} }
@@ -580,7 +580,6 @@ TEST_F(LoadBalancerFetchLegerTests, fetch_bothSourcesFail)
struct LoadBalancerForwardToRippledTests : LoadBalancerConstructorTests, SyncAsioContextTest { struct LoadBalancerForwardToRippledTests : LoadBalancerConstructorTests, SyncAsioContextTest {
LoadBalancerForwardToRippledTests() LoadBalancerForwardToRippledTests()
{ {
util::Random::setSeed(0);
EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled).WillOnce(Return(boost::json::object{})); EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled).WillOnce(Return(boost::json::object{}));
EXPECT_CALL(sourceFactory_.sourceAt(0), run); EXPECT_CALL(sourceFactory_.sourceAt(0), run);
EXPECT_CALL(sourceFactory_.sourceAt(1), forwardToRippled).WillOnce(Return(boost::json::object{})); EXPECT_CALL(sourceFactory_.sourceAt(1), forwardToRippled).WillOnce(Return(boost::json::object{}));
@@ -814,6 +813,7 @@ TEST_F(LoadBalancerForwardToRippledTests, onLedgerClosedHookInvalidatesCache)
auto const request = boost::json::object{{"command", "server_info"}}; auto const request = boost::json::object{{"command", "server_info"}};
EXPECT_CALL(*randomGenerator_, uniform(0, 1)).WillOnce(Return(0)).WillOnce(Return(1));
EXPECT_CALL( EXPECT_CALL(
sourceFactory_.sourceAt(0), sourceFactory_.sourceAt(0),
forwardToRippled(request, clientIP_, LoadBalancer::kUSER_FORWARDING_X_USER_VALUE, testing::_) forwardToRippled(request, clientIP_, LoadBalancer::kUSER_FORWARDING_X_USER_VALUE, testing::_)

View File

@@ -26,6 +26,7 @@
#include "util/MockBackendTestFixture.hpp" #include "util/MockBackendTestFixture.hpp"
#include "util/MockNetworkValidatedLedgers.hpp" #include "util/MockNetworkValidatedLedgers.hpp"
#include "util/MockPrometheus.hpp" #include "util/MockPrometheus.hpp"
#include "util/MockRandomGenerator.hpp"
#include "util/MockSourceNg.hpp" #include "util/MockSourceNg.hpp"
#include "util/MockSubscriptionManager.hpp" #include "util/MockSubscriptionManager.hpp"
#include "util/NameGenerator.hpp" #include "util/NameGenerator.hpp"
@@ -144,17 +145,23 @@ struct LoadBalancerConstructorNgTests : util::prometheus::WithPrometheus, MockBa
makeLoadBalancer() makeLoadBalancer()
{ {
auto const cfg = getParseLoadBalancerConfig(configJson_); auto const cfg = getParseLoadBalancerConfig(configJson_);
auto randomGenerator = std::make_unique<MockRandomGenerator>();
randomGenerator_ = randomGenerator.get();
return std::make_unique<LoadBalancer>( return std::make_unique<LoadBalancer>(
cfg, cfg,
ioContext_, ioContext_,
backend_, backend_,
subscriptionManager_, subscriptionManager_,
std::move(randomGenerator),
networkManager_, networkManager_,
[this](auto&&... args) -> SourcePtr { return sourceFactory_(std::forward<decltype(args)>(args)...); } [this](auto&&... args) -> SourcePtr { return sourceFactory_(std::forward<decltype(args)>(args)...); }
); );
} }
protected: protected:
MockRandomGenerator* randomGenerator_ = nullptr;
StrictMockSubscriptionManagerSharedPtr subscriptionManager_; StrictMockSubscriptionManagerSharedPtr subscriptionManager_;
StrictMockNetworkValidatedLedgersPtr networkManager_; StrictMockNetworkValidatedLedgersPtr networkManager_;
StrictMockSourceNgFactory sourceFactory_{2}; StrictMockSourceNgFactory sourceFactory_{2};
@@ -450,11 +457,6 @@ TEST_F(LoadBalancer3SourcesNgTests, forwardingUpdate)
} }
struct LoadBalancerLoadInitialLedgerNgTests : LoadBalancerOnConnectHookNgTests { struct LoadBalancerLoadInitialLedgerNgTests : LoadBalancerOnConnectHookNgTests {
LoadBalancerLoadInitialLedgerNgTests()
{
util::Random::setSeed(0);
}
protected: protected:
uint32_t const sequence_ = 123; uint32_t const sequence_ = 123;
uint32_t const numMarkers_ = 16; uint32_t const numMarkers_ = 16;
@@ -522,7 +524,6 @@ TEST_F(LoadBalancerLoadInitialLedgerCustomNumMarkersNgTests, loadInitialLedger)
EXPECT_CALL(sourceFactory_.sourceAt(1), run); EXPECT_CALL(sourceFactory_.sourceAt(1), run);
auto loadBalancer = makeLoadBalancer(); auto loadBalancer = makeLoadBalancer();
util::Random::setSeed(0);
EXPECT_CALL(sourceFactory_.sourceAt(0), hasLedger(sequence_)).WillOnce(Return(true)); EXPECT_CALL(sourceFactory_.sourceAt(0), hasLedger(sequence_)).WillOnce(Return(true));
EXPECT_CALL(sourceFactory_.sourceAt(0), loadInitialLedger(sequence_, numMarkers_, testing::_)) EXPECT_CALL(sourceFactory_.sourceAt(0), loadInitialLedger(sequence_, numMarkers_, testing::_))
.WillOnce(Return(response_)); .WillOnce(Return(response_));
@@ -533,7 +534,6 @@ TEST_F(LoadBalancerLoadInitialLedgerCustomNumMarkersNgTests, loadInitialLedger)
struct LoadBalancerFetchLegerNgTests : LoadBalancerOnConnectHookNgTests { struct LoadBalancerFetchLegerNgTests : LoadBalancerOnConnectHookNgTests {
LoadBalancerFetchLegerNgTests() LoadBalancerFetchLegerNgTests()
{ {
util::Random::setSeed(0);
response_.second.set_validated(true); response_.second.set_validated(true);
} }
@@ -607,7 +607,6 @@ TEST_F(LoadBalancerFetchLegerNgTests, fetch_bothSourcesFail)
struct LoadBalancerForwardToRippledNgTests : LoadBalancerConstructorNgTests, SyncAsioContextTest { struct LoadBalancerForwardToRippledNgTests : LoadBalancerConstructorNgTests, SyncAsioContextTest {
LoadBalancerForwardToRippledNgTests() LoadBalancerForwardToRippledNgTests()
{ {
util::Random::setSeed(0);
EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled).WillOnce(Return(boost::json::object{})); EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled).WillOnce(Return(boost::json::object{}));
EXPECT_CALL(sourceFactory_.sourceAt(0), run); EXPECT_CALL(sourceFactory_.sourceAt(0), run);
EXPECT_CALL(sourceFactory_.sourceAt(1), forwardToRippled).WillOnce(Return(boost::json::object{})); EXPECT_CALL(sourceFactory_.sourceAt(1), forwardToRippled).WillOnce(Return(boost::json::object{}));
@@ -841,6 +840,8 @@ TEST_F(LoadBalancerForwardToRippledNgTests, onLedgerClosedHookInvalidatesCache)
auto const request = boost::json::object{{"command", "server_info"}}; auto const request = boost::json::object{{"command", "server_info"}};
EXPECT_CALL(*randomGenerator_, uniform(0, 1)).WillOnce(Return(0)).WillOnce(Return(1));
EXPECT_CALL( EXPECT_CALL(
sourceFactory_.sourceAt(0), sourceFactory_.sourceAt(0),
forwardToRippled(request, clientIP_, LoadBalancer::kUSER_FORWARDING_X_USER_VALUE, testing::_) forwardToRippled(request, clientIP_, LoadBalancer::kUSER_FORWARDING_X_USER_VALUE, testing::_)

View File

@@ -30,18 +30,21 @@ using namespace util;
struct RandomTests : public ::testing::Test { struct RandomTests : public ::testing::Test {
static std::vector<int> static std::vector<int>
generateRandoms(size_t const numRandoms = 1000) generateRandoms(MTRandomGenerator& randomGenerator, size_t const numRandoms = 1000)
{ {
std::vector<int> v; std::vector<int> v;
v.reserve(numRandoms); v.reserve(numRandoms);
std::ranges::generate_n(std::back_inserter(v), numRandoms, []() { return Random::uniform(0, 1000); }); std::ranges::generate_n(std::back_inserter(v), numRandoms, [&randomGenerator]() {
return randomGenerator.uniform(0, 1000);
});
return v; return v;
} }
}; };
TEST_F(RandomTests, Uniform) TEST_F(RandomTests, Uniform)
{ {
std::ranges::for_each(generateRandoms(), [](int const& e) { MTRandomGenerator randomGenerator;
std::ranges::for_each(generateRandoms(randomGenerator), [](int const& e) {
EXPECT_GE(e, 0); EXPECT_GE(e, 0);
EXPECT_LE(e, 1000); EXPECT_LE(e, 1000);
}); });
@@ -49,11 +52,13 @@ TEST_F(RandomTests, Uniform)
TEST_F(RandomTests, FixedSeed) TEST_F(RandomTests, FixedSeed)
{ {
Random::setSeed(42); MTRandomGenerator randomGenerator;
std::vector<int> const v1 = generateRandoms();
Random::setSeed(42); randomGenerator.setSeed(42);
std::vector<int> const v2 = generateRandoms(); std::vector<int> const v1 = generateRandoms(randomGenerator);
randomGenerator.setSeed(42);
std::vector<int> const v2 = generateRandoms(randomGenerator);
ASSERT_EQ(v1.size(), v2.size()); ASSERT_EQ(v1.size(), v2.size());
for (size_t i = 0; i < v1.size(); ++i) { for (size_t i = 0; i < v1.size(); ++i) {