Mirror of https://github.com/XRPLF/clio.git (synced 2025-11-04 11:55:51 +00:00)

Compare commits

23 Commits

890279b15e ... 2.3.0-b3
| Author | SHA1 | Date |
|---|---|---|
|  | 9df3e936cc |  |
|  | 4166c46820 |  |
|  | f75cbd456b |  |
|  | d189651821 |  |
|  | 3f791c1315 |  |
|  | 418511332e |  |
|  | e5a0477352 |  |
|  | 3118110eb8 |  |
|  | 6d20f39f67 |  |
|  | 9cb1e06c8e |  |
|  | 423244eb4b |  |
|  | 7aaba1cbad |  |
|  | b7c50fd73d |  |
|  | 442ee874d5 |  |
|  | 0679034978 |  |
|  | b41ea34212 |  |
|  | 4e147deafa |  |
|  | b08447e8e0 |  |
|  | 9432165ace |  |
|  | 3b6a87249c |  |
|  | b7449f72b7 |  |
|  | 443c74436e |  |
|  | 7b5e02731d |  |
@@ -42,7 +42,7 @@ verify_tag_signed() {
while read local_ref local_oid remote_ref remote_oid; do
    # Check some things if we're pushing a branch called "release/"
    if echo "$remote_ref" | grep ^refs\/heads\/release\/ &> /dev/null ; then
        version=$(echo $remote_ref | awk -F/ '{print $NF}')
        version=$(git tag --points-at HEAD)
        echo "Looks like you're trying to push a $version release..."
        echo "Making sure you've signed and tagged it."
        if verify_commit_signed && verify_tag && verify_tag_signed ; then

.github/workflows/clang-tidy.yml (2 changed lines, vendored)
@@ -99,7 +99,7 @@ jobs:
      - name: Create PR with fixes
        if: ${{ steps.run_clang_tidy.outcome != 'success' }}
        uses: peter-evans/create-pull-request@v6
        uses: peter-evans/create-pull-request@v7
        env:
          GH_REPO: ${{ github.repository }}
          GH_TOKEN: ${{ github.token }}
@@ -28,7 +28,7 @@ class Clio(ConanFile):
        'protobuf/3.21.9',
        'grpc/1.50.1',
        'openssl/1.1.1u',
        'xrpl/2.3.0-b1',
        'xrpl/2.3.0-b4',
        'libbacktrace/cci.20210118'
    ]
@@ -101,9 +101,7 @@ ClioApplication::run()
    auto backend = data::make_Backend(config_);

    // Manages clients subscribed to streams
    auto subscriptionsRunner = feed::SubscriptionManagerRunner(config_, backend);

    auto const subscriptions = subscriptionsRunner.getManager();
    auto subscriptions = feed::SubscriptionManager::make_SubscriptionManager(config_, backend);

    // Tracks which ledgers have been validated by the network
    auto ledgers = etl::NetworkValidatedLedgers::make_ValidatedLedgers();
@@ -124,6 +124,9 @@ struct Amendments {
    REGISTER(NFTokenMintOffer);
    REGISTER(fixReducedOffersV2);
    REGISTER(fixEnforceNFTokenTrustline);
    REGISTER(fixInnerObjTemplate2);
    REGISTER(fixNFTokenPageLinks);
    REGISTER(InvariantsV1_1);

    // Obsolete but supported by libxrpl
    REGISTER(CryptoConditionsSuite);
@@ -111,10 +111,13 @@ LoadBalancer::LoadBalancer(
            validatedLedgers,
            forwardingTimeout,
            [this]() {
                if (not hasForwardingSource_)
                if (not hasForwardingSource_.lock().get())
                    chooseForwardingSource();
            },
            [this](bool wasForwarding) {
                if (wasForwarding)
                    chooseForwardingSource();
            },
            [this]() { chooseForwardingSource(); },
            [this]() {
                if (forwardingCache_.has_value())
                    forwardingCache_->invalidate();

@@ -322,11 +325,13 @@ LoadBalancer::getETLState() noexcept
void
LoadBalancer::chooseForwardingSource()
{
    hasForwardingSource_ = false;
    LOG(log_.info()) << "Choosing a new source to forward subscriptions";
    auto hasForwardingSourceLock = hasForwardingSource_.lock();
    hasForwardingSourceLock.get() = false;
    for (auto& source : sources_) {
        if (not hasForwardingSource_ and source->isConnected()) {
        if (not hasForwardingSourceLock.get() and source->isConnected()) {
            source->setForwarding(true);
            hasForwardingSource_ = true;
            hasForwardingSourceLock.get() = true;
        } else {
            source->setForwarding(false);
        }
@@ -25,7 +25,7 @@
#include "etl/Source.hpp"
#include "etl/impl/ForwardingCache.hpp"
#include "feed/SubscriptionManagerInterface.hpp"
#include "rpc/Errors.hpp"
#include "util/Mutex.hpp"
#include "util/config/Config.hpp"
#include "util/log/Logger.hpp"

@@ -39,7 +39,6 @@
#include <org/xrpl/rpc/v1/ledger.pb.h>
#include <xrpl/proto/org/xrpl/rpc/v1/xrp_ledger.grpc.pb.h>

#include <atomic>
#include <chrono>
#include <cstdint>
#include <expected>
@@ -76,7 +75,10 @@ private:
    std::optional<ETLState> etlState_;
    std::uint32_t downloadRanges_ =
        DEFAULT_DOWNLOAD_RANGES; /*< The number of markers to use when downloading initial ledger */
    std::atomic_bool hasForwardingSource_{false};

    // Using mutex instead of atomic_bool because choosing a new source to
    // forward messages should be done with mutual exclusion, otherwise there will be a race condition
    util::Mutex<bool> hasForwardingSource_{false};

public:
    /**
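The switch from `std::atomic_bool` to `util::Mutex<bool>` matters because the flag is read and then written as one decision. Below is a minimal, self-contained sketch of the idea; `GuardedValue` is a hypothetical stand-in written only to illustrate the `lock()` / `get()` calls visible in this diff, not Clio's actual `util::Mutex` implementation.

```cpp
#include <mutex>
#include <utility>

// Hypothetical stand-in for util::Mutex<T>: lock() returns a guard whose get()
// exposes the protected value for as long as the guard is alive.
template <typename T>
class GuardedValue {
    std::mutex mtx_;
    T value_;

public:
    explicit GuardedValue(T v) : value_(std::move(v)) {}

    class Guard {
        std::unique_lock<std::mutex> lock_;
        T& ref_;

    public:
        Guard(std::mutex& m, T& v) : lock_(m), ref_(v) {}
        T& get() { return ref_; }
    };

    Guard lock() { return Guard{mtx_, value_}; }
};

// With a bare std::atomic_bool, the "is it false?" check and the later
// "set it to true" are two separate atomic operations, so two threads can
// both observe false and both pick a forwarding source. Holding one guard
// across the whole read-modify-write closes that window.
void pickOnce(GuardedValue<bool>& hasForwardingSource)
{
    auto guard = hasForwardingSource.lock();
    if (not guard.get())
        guard.get() = true;  // still under the same lock as the check above
}
```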
@@ -53,7 +53,7 @@ namespace etl {
class SourceBase {
public:
    using OnConnectHook = std::function<void()>;
    using OnDisconnectHook = std::function<void()>;
    using OnDisconnectHook = std::function<void(bool)>;
    using OnLedgerClosedHook = std::function<void()>;

    virtual ~SourceBase() = default;
@@ -47,7 +47,7 @@
namespace etl::impl {

GrpcSource::GrpcSource(std::string const& ip, std::string const& grpcPort, std::shared_ptr<BackendInterface> backend)
    : log_(fmt::format("ETL_Grpc[{}:{}]", ip, grpcPort)), backend_(std::move(backend))
    : log_(fmt::format("GrpcSource[{}:{}]", ip, grpcPort)), backend_(std::move(backend))
{
    try {
        boost::asio::io_context ctx;
@@ -24,6 +24,8 @@
#include "rpc/JS.hpp"
#include "util/Retry.hpp"
#include "util/log/Logger.hpp"
#include "util/prometheus/Label.hpp"
#include "util/prometheus/Prometheus.hpp"
#include "util/requests/Types.hpp"

#include <boost/algorithm/string/classification.hpp>
@@ -66,22 +68,28 @@ SubscriptionSource::SubscriptionSource(
    OnConnectHook onConnect,
    OnDisconnectHook onDisconnect,
    OnLedgerClosedHook onLedgerClosed,
    std::chrono::steady_clock::duration const connectionTimeout,
    std::chrono::steady_clock::duration const wsTimeout,
    std::chrono::steady_clock::duration const retryDelay
)
    : log_(fmt::format("GrpcSource[{}:{}]", ip, wsPort))
    : log_(fmt::format("SubscriptionSource[{}:{}]", ip, wsPort))
    , wsConnectionBuilder_(ip, wsPort)
    , validatedLedgers_(std::move(validatedLedgers))
    , subscriptions_(std::move(subscriptions))
    , strand_(boost::asio::make_strand(ioContext))
    , wsTimeout_(wsTimeout)
    , retry_(util::makeRetryExponentialBackoff(retryDelay, RETRY_MAX_DELAY, strand_))
    , onConnect_(std::move(onConnect))
    , onDisconnect_(std::move(onDisconnect))
    , onLedgerClosed_(std::move(onLedgerClosed))
    , lastMessageTimeSecondsSinceEpoch_(PrometheusService::gaugeInt(
          "subscription_source_last_message_time",
          util::prometheus::Labels({{"source", fmt::format("{}:{}", ip, wsPort)}}),
          "Seconds since epoch of the last message received from rippled subscription streams"
      ))
{
    wsConnectionBuilder_.addHeader({boost::beast::http::field::user_agent, "clio-client"})
        .addHeader({"X-User", "clio-client"})
        .setConnectionTimeout(connectionTimeout);
        .setConnectionTimeout(wsTimeout_);
}

SubscriptionSource::~SubscriptionSource()
@@ -133,6 +141,7 @@ void
SubscriptionSource::setForwarding(bool isForwarding)
{
    isForwarding_ = isForwarding;
    LOG(log_.info()) << "Forwarding set to " << isForwarding_;
}

std::chrono::steady_clock::time_point
@@ -166,20 +175,22 @@ SubscriptionSource::subscribe()
            }

            wsConnection_ = std::move(connection).value();
            isConnected_ = true;
            onConnect_();

            auto const& subscribeCommand = getSubscribeCommandJson();
            auto const writeErrorOpt = wsConnection_->write(subscribeCommand, yield);
            auto const writeErrorOpt = wsConnection_->write(subscribeCommand, yield, wsTimeout_);
            if (writeErrorOpt) {
                handleError(writeErrorOpt.value(), yield);
                return;
            }

            isConnected_ = true;
            LOG(log_.info()) << "Connected";
            onConnect_();

            retry_.reset();

            while (!stop_) {
                auto const message = wsConnection_->read(yield);
                auto const message = wsConnection_->read(yield, wsTimeout_);
                if (not message) {
                    handleError(message.error(), yield);
                    return;
@@ -224,10 +235,11 @@ SubscriptionSource::handleMessage(std::string const& message)
                auto validatedLedgers = boost::json::value_to<std::string>(result.at(JS(validated_ledgers)));
                setValidatedRange(std::move(validatedLedgers));
            }
            LOG(log_.info()) << "Received a message on ledger subscription stream. Message : " << object;
            LOG(log_.debug()) << "Received a message on ledger subscription stream. Message: " << object;

        } else if (object.contains(JS(type)) && object.at(JS(type)) == JS_LedgerClosed) {
            LOG(log_.info()) << "Received a message on ledger subscription stream. Message : " << object;
            LOG(log_.debug()) << "Received a message of type 'ledgerClosed' on ledger subscription stream. Message: "
                              << object;
            if (object.contains(JS(ledger_index))) {
                ledgerIndex = object.at(JS(ledger_index)).as_int64();
            }
@@ -245,10 +257,13 @@ SubscriptionSource::handleMessage(std::string const& message)
                // 2 - Validated transaction
                // Only forward proposed transaction, validated transactions are sent by Clio itself
                if (object.contains(JS(transaction)) and !object.contains(JS(meta))) {
                    LOG(log_.debug()) << "Forwarding proposed transaction: " << object;
                    subscriptions_->forwardProposedTransaction(object);
                } else if (object.contains(JS(type)) && object.at(JS(type)) == JS_ValidationReceived) {
                    LOG(log_.debug()) << "Forwarding validation: " << object;
                    subscriptions_->forwardValidation(object);
                } else if (object.contains(JS(type)) && object.at(JS(type)) == JS_ManifestReceived) {
                    LOG(log_.debug()) << "Forwarding manifest: " << object;
                    subscriptions_->forwardManifest(object);
                }
            }
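For context, the branch above classifies messages purely by shape: a proposed transaction carries a `transaction` field but no `meta`, while validations and manifests are identified by their `type` field. A self-contained sketch of that classification follows; the literal "validationReceived" and "manifestReceived" strings are assumptions about what the `JS_ValidationReceived` and `JS_ManifestReceived` constants expand to.

```cpp
#include <boost/json.hpp>
#include <iostream>

// Mirrors the containment checks in the diff above; only the type strings are assumed.
static char const*
classify(boost::json::object const& object)
{
    if (object.contains("transaction") && !object.contains("meta"))
        return "proposed transaction -> forwardProposedTransaction";
    if (object.contains("type") && object.at("type").as_string() == "validationReceived")
        return "validation -> forwardValidation";
    if (object.contains("type") && object.at("type").as_string() == "manifestReceived")
        return "manifest -> forwardManifest";
    return "not forwarded by this branch";
}

int main()
{
    auto const proposed = boost::json::parse(R"({"transaction": {"TransactionType": "Payment"}})");
    auto const validated = boost::json::parse(R"({"transaction": {}, "meta": {}})");
    std::cout << classify(proposed.as_object()) << '\n';   // proposed transaction -> forwardProposedTransaction
    std::cout << classify(validated.as_object()) << '\n';  // not forwarded by this branch
}
```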
@@ -261,7 +276,7 @@ SubscriptionSource::handleMessage(std::string const& message)

        return std::nullopt;
    } catch (std::exception const& e) {
        LOG(log_.error()) << "Exception in handleMessage : " << e.what();
        LOG(log_.error()) << "Exception in handleMessage: " << e.what();
        return util::requests::RequestError{fmt::format("Error handling message: {}", e.what())};
    }
}
@@ -270,16 +285,14 @@ void
SubscriptionSource::handleError(util::requests::RequestError const& error, boost::asio::yield_context yield)
{
    isConnected_ = false;
    isForwarding_ = false;
    bool const wasForwarding = isForwarding_.exchange(false);
    if (not stop_) {
        onDisconnect_();
        LOG(log_.info()) << "Disconnected";
        onDisconnect_(wasForwarding);
    }

    if (wsConnection_ != nullptr) {
        auto const err = wsConnection_->close(yield);
        if (err) {
            LOG(log_.error()) << "Error closing websocket connection: " << err->message();
        }
        wsConnection_->close(yield);
        wsConnection_.reset();
    }
@@ -306,7 +319,11 @@ SubscriptionSource::logError(util::requests::RequestError const& error) const
void
SubscriptionSource::setLastMessageTime()
{
    lastMessageTime_.lock().get() = std::chrono::steady_clock::now();
    lastMessageTimeSecondsSinceEpoch_.get().set(
        std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch()).count()
    );
    auto lock = lastMessageTime_.lock();
    lock.get() = std::chrono::steady_clock::now();
}

void
@@ -25,6 +25,7 @@
#include "util/Mutex.hpp"
#include "util/Retry.hpp"
#include "util/log/Logger.hpp"
#include "util/prometheus/Gauge.hpp"
#include "util/requests/Types.hpp"
#include "util/requests/WsConnection.hpp"

@@ -37,6 +38,7 @@
#include <atomic>
#include <chrono>
#include <cstdint>
#include <functional>
#include <future>
#include <memory>
#include <optional>
@@ -71,6 +73,8 @@ private:

    boost::asio::strand<boost::asio::io_context::executor_type> strand_;

    std::chrono::steady_clock::duration wsTimeout_;

    util::Retry retry_;

    OnConnectHook onConnect_;

@@ -83,9 +87,11 @@ private:

    util::Mutex<std::chrono::steady_clock::time_point> lastMessageTime_;

    std::reference_wrapper<util::prometheus::GaugeInt> lastMessageTimeSecondsSinceEpoch_;

    std::future<void> runFuture_;

    static constexpr std::chrono::seconds CONNECTION_TIMEOUT{30};
    static constexpr std::chrono::seconds WS_TIMEOUT{30};
    static constexpr std::chrono::seconds RETRY_MAX_DELAY{30};
    static constexpr std::chrono::seconds RETRY_DELAY{1};
@@ -103,7 +109,7 @@ public:
     * @param onDisconnect The onDisconnect hook. Called when the connection is lost
     * @param onLedgerClosed The onLedgerClosed hook. Called when the ledger is closed but only if the source is
     * forwarding
     * @param connectionTimeout The connection timeout. Defaults to 30 seconds
     * @param wsTimeout A timeout for websocket operations. Defaults to 30 seconds
     * @param retryDelay The retry delay. Defaults to 1 second
     */
    SubscriptionSource(

@@ -115,7 +121,7 @@ public:
        OnConnectHook onConnect,
        OnDisconnectHook onDisconnect,
        OnLedgerClosedHook onLedgerClosed,
        std::chrono::steady_clock::duration const connectionTimeout = CONNECTION_TIMEOUT,
        std::chrono::steady_clock::duration const wsTimeout = WS_TIMEOUT,
        std::chrono::steady_clock::duration const retryDelay = RETRY_DELAY
    );
@@ -30,6 +30,7 @@
#include "feed/impl/TransactionFeed.hpp"
#include "util/async/AnyExecutionContext.hpp"
#include "util/async/context/BasicExecutionContext.hpp"
#include "util/config/Config.hpp"
#include "util/log/Logger.hpp"

#include <boost/asio/executor_work_guard.hpp>

@@ -44,6 +45,7 @@
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

/**
@@ -67,16 +69,36 @@ class SubscriptionManager : public SubscriptionManagerInterface {
    impl::ProposedTransactionFeed proposedTransactionFeed_;

public:
    /**
     * @brief Factory function to create a new SubscriptionManager with a PoolExecutionContext.
     *
     * @param config The configuration to use
     * @param backend The backend to use
     * @return A shared pointer to a new instance of SubscriptionManager
     */
    static std::shared_ptr<SubscriptionManager>
    make_SubscriptionManager(util::Config const& config, std::shared_ptr<data::BackendInterface const> const& backend)
    {
        auto const workersNum = config.valueOr<std::uint64_t>("subscription_workers", 1);

        util::Logger const logger{"Subscriptions"};
        LOG(logger.info()) << "Starting subscription manager with " << workersNum << " workers";

        return std::make_shared<feed::SubscriptionManager>(util::async::PoolExecutionContext(workersNum), backend);
    }

    /**
     * @brief Construct a new Subscription Manager object
     *
     * @param executor The executor to use to publish the feeds
     * @param backend The backend to use
     */
    template <class ExecutorCtx>
    SubscriptionManager(ExecutorCtx& executor, std::shared_ptr<data::BackendInterface const> const& backend)
    SubscriptionManager(
        util::async::AnyExecutionContext&& executor,
        std::shared_ptr<data::BackendInterface const> const& backend
    )
        : backend_(backend)
        , ctx_(executor)
        , ctx_(std::move(executor))
        , manifestFeed_(ctx_, "manifest")
        , validationsFeed_(ctx_, "validations")
        , ledgerFeed_(ctx_)
@@ -291,41 +313,4 @@ public:
    report() const final;
};

/**
 * @brief The help class to run the subscription manager. The container of PoolExecutionContext which is used to publish
 * the feeds.
 */
class SubscriptionManagerRunner {
    std::uint64_t workersNum_;
    using ActualExecutionCtx = util::async::PoolExecutionContext;
    ActualExecutionCtx ctx_;
    std::shared_ptr<SubscriptionManager> subscriptionManager_;
    util::Logger logger_{"Subscriptions"};

public:
    /**
     * @brief Construct a new Subscription Manager Runner object
     *
     * @param config The configuration
     * @param backend The backend to use
     */
    SubscriptionManagerRunner(util::Config const& config, std::shared_ptr<data::BackendInterface> const& backend)
        : workersNum_(config.valueOr<std::uint64_t>("subscription_workers", 1))
        , ctx_(workersNum_)
        , subscriptionManager_(std::make_shared<SubscriptionManager>(ctx_, backend))
    {
        LOG(logger_.info()) << "Starting subscription manager with " << workersNum_ << " workers";
    }

    /**
     * @brief Get the subscription manager
     *
     * @return The subscription manager
     */
    std::shared_ptr<SubscriptionManager>
    getManager()
    {
        return subscriptionManager_;
    }
};
}  // namespace feed
@@ -22,6 +22,7 @@
#include "data/BackendInterface.hpp"
#include "rpc/Counters.hpp"
#include "rpc/Errors.hpp"
#include "rpc/RPCHelpers.hpp"
#include "rpc/WorkQueue.hpp"
#include "rpc/common/HandlerProvider.hpp"
#include "rpc/common/Types.hpp"

@@ -131,8 +132,13 @@ public:
    Result
    buildResponse(web::Context const& ctx)
    {
        if (forwardingProxy_.shouldForward(ctx))
        if (forwardingProxy_.shouldForward(ctx)) {
            // Disallow forwarding of the admin api, only user api is allowed for security reasons.
            if (isAdminCmd(ctx.method, ctx.params))
                return Result{Status{RippledError::rpcNO_PERMISSION}};

            return forwardingProxy_.forward(ctx);
        }

        if (backend_->isTooBusy()) {
            LOG(log_.error()) << "Database is too busy. Rejecting request";
@@ -36,6 +36,7 @@
#include <boost/json/array.hpp>
#include <boost/json/object.hpp>
#include <boost/json/parse.hpp>
#include <boost/json/serialize.hpp>
#include <boost/json/string.hpp>
#include <boost/json/value.hpp>
#include <boost/json/value_to.hpp>

@@ -49,6 +50,7 @@
#include <xrpl/basics/chrono.h>
#include <xrpl/basics/strHex.h>
#include <xrpl/beast/utility/Zero.h>
#include <xrpl/json/json_reader.h>
#include <xrpl/json/json_value.h>
#include <xrpl/protocol/AccountID.h>
#include <xrpl/protocol/Book.h>

@@ -1273,6 +1275,31 @@ specifiesCurrentOrClosedLedger(boost::json::object const& request)
    return false;
}

bool
isAdminCmd(std::string const& method, boost::json::object const& request)
{
    if (method == JS(ledger)) {
        auto const requestStr = boost::json::serialize(request);
        Json::Value jv;
        Json::Reader{}.parse(requestStr, jv);
        // rippled considers string/non-zero int/non-empty array/non-empty json as true.
        // Use rippled's API asBool to get the same result.
        // https://github.com/XRPLF/rippled/issues/5119
        auto const isFieldSet = [&jv](auto const field) { return jv.isMember(field) and jv[field].asBool(); };

        // According to doc
        // https://xrpl.org/docs/references/http-websocket-apis/public-api-methods/ledger-methods/ledger,
        // full/accounts/type are admin only, but type only works when full/accounts are set, so we don't need to check
        // type.
        if (isFieldSet(JS(full)) or isFieldSet(JS(accounts)))
            return true;
    }

    if (method == JS(feature) and request.contains(JS(vetoed)))
        return true;
    return false;
}

std::variant<ripple::uint256, Status>
getNFTID(boost::json::object const& request)
{
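To make the effect of this check concrete, here is a small, self-contained approximation. `looksAdminOnly` is a hypothetical simplification that only handles boolean fields, whereas the real `isAdminCmd` round-trips the request through rippled's `Json::Value::asBool`, so strings, non-zero ints, and non-empty arrays or objects also count as "set".

```cpp
#include <boost/json.hpp>
#include <iostream>
#include <string>

// Simplified stand-in for the decision above (boolean fields only).
static bool
looksAdminOnly(std::string const& method, boost::json::object const& request)
{
    auto const truthy = [&request](char const* field) {
        return request.contains(field) && request.at(field).is_bool() && request.at(field).as_bool();
    };
    if (method == "ledger" && (truthy("full") || truthy("accounts")))
        return true;
    return method == "feature" && request.contains("vetoed");
}

int main()
{
    std::cout << std::boolalpha;
    std::cout << looksAdminOnly("ledger", {{"full", true}}) << '\n';       // true: rejected with rpcNO_PERMISSION instead of being forwarded
    std::cout << looksAdminOnly("ledger", {{"ledger_index", 5}}) << '\n';  // false: eligible for forwarding
    std::cout << looksAdminOnly("feature", {{"vetoed", "Foo"}}) << '\n';   // true: admin only
}
```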
@@ -557,6 +557,16 @@ parseIssue(boost::json::object const& issue);
bool
specifiesCurrentOrClosedLedger(boost::json::object const& request);

/**
 * @brief Check whether a request requires administrative privileges on rippled side.
 *
 * @param method The method name to check
 * @param request The request to check
 * @return true if the request requires ADMIN role
 */
bool
isAdminCmd(std::string const& method, boost::json::object const& request);

/**
 * @brief Get the NFTID from the request
 *
@@ -200,15 +200,13 @@ CustomValidator CustomValidators::SubscribeStreamValidator =
            "ledger", "transactions", "transactions_proposed", "book_changes", "manifests", "validations"
        };

        static std::unordered_set<std::string> const reportingNotSupportStreams = {
            "peer_status", "consensus", "server"
        };
        static std::unordered_set<std::string> const notSupportStreams = {"peer_status", "consensus", "server"};
        for (auto const& v : value.as_array()) {
            if (!v.is_string())
                return Error{Status{RippledError::rpcINVALID_PARAMS, "streamNotString"}};

            if (reportingNotSupportStreams.contains(boost::json::value_to<std::string>(v)))
                return Error{Status{RippledError::rpcREPORTING_UNSUPPORTED}};
            if (notSupportStreams.contains(boost::json::value_to<std::string>(v)))
                return Error{Status{RippledError::rpcNOT_SUPPORTED}};

            if (not validStreams.contains(boost::json::value_to<std::string>(v)))
                return Error{Status{RippledError::rpcSTREAM_MALFORMED}};
@@ -60,10 +60,6 @@ public:
        if (ctx.method == "subscribe" || ctx.method == "unsubscribe")
            return false;

        // Disallow forwarding of the admin api, only user api is allowed for security reasons.
        if (ctx.method == "feature" and request.contains("vetoed"))
            return false;

        if (handlerProvider_->isClioOnly(ctx.method))
            return false;

@@ -73,6 +69,9 @@ public:
        if (specifiesCurrentOrClosedLedger(request))
            return true;

        if (isForcedForward(ctx))
            return true;

        auto const checkAccountInfoForward = [&]() {
            return ctx.method == "account_info" and request.contains("queue") and request.at("queue").is_bool() and
                request.at("queue").as_bool();

@@ -142,6 +141,14 @@ private:
    {
        return handlerProvider_->contains(method) || isProxied(method);
    }

    bool
    isForcedForward(web::Context const& ctx) const
    {
        static constexpr auto FORCE_FORWARD = "force_forward";
        return ctx.isAdmin and ctx.params.contains(FORCE_FORWARD) and ctx.params.at(FORCE_FORWARD).is_bool() and
            ctx.params.at(FORCE_FORWARD).as_bool();
    }
};

}  // namespace rpc::impl
@@ -78,10 +78,17 @@ AccountNFTsHandler::process(AccountNFTsHandler::Input input, Context const& ctx)
        input.marker ? ripple::uint256{input.marker->c_str()} : ripple::keylet::nftpage_max(*accountID).key;
    auto const blob = sharedPtrBackend_->fetchLedgerObject(pageKey, lgrInfo.seq, ctx.yield);

    if (!blob)
    if (!blob) {
        if (input.marker.has_value())
            return Error{Status{RippledError::rpcINVALID_PARAMS, "Marker field does not match any valid Page ID"}};
        return response;
    }

    std::optional<ripple::SLE const> page{ripple::SLE{ripple::SerialIter{blob->data(), blob->size()}, pageKey}};

    if (page->getType() != ripple::ltNFTOKEN_PAGE)
        return Error{Status{RippledError::rpcINVALID_PARAMS, "Marker matches Page ID from another Account"}};

    auto numPages = 0u;

    while (page) {
@@ -25,9 +25,12 @@ target_sources(
          TimeUtils.cpp
          TxUtils.cpp
          LedgerUtils.cpp
          newconfig/ConfigDefinition.cpp
          newconfig/ObjectView.cpp
          newconfig/Array.cpp
          newconfig/ArrayView.cpp
          newconfig/ConfigConstraints.cpp
          newconfig/ConfigDefinition.cpp
          newconfig/ConfigFileJson.cpp
          newconfig/ObjectView.cpp
          newconfig/ValueView.cpp
)
@@ -369,7 +369,7 @@ public:
     * @brief Block until all operations are completed
     */
    void
    join() noexcept
    join() const noexcept
    {
        context_.join();
    }

src/util/newconfig/Array.cpp (new file, 84 lines)
@@ -0,0 +1,84 @@
//------------------------------------------------------------------------------
/*
    This file is part of clio: https://github.com/XRPLF/clio
    Copyright (c) 2024, the clio developers.

    Permission to use, copy, modify, and distribute this software for any
    purpose with or without fee is hereby granted, provided that the above
    copyright notice and this permission notice appear in all copies.

    THE  SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
    WITH  REGARD  TO  THIS  SOFTWARE  INCLUDING  ALL  IMPLIED  WARRANTIES  OF
    MERCHANTABILITY  AND  FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
    ANY  SPECIAL,  DIRECT,  INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
    WHATSOEVER  RESULTING  FROM  LOSS  OF USE, DATA OR PROFITS, WHETHER IN AN
    ACTION  OF  CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
    OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================

#include "util/newconfig/Array.hpp"

#include "util/Assert.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Error.hpp"
#include "util/newconfig/Types.hpp"

#include <cstddef>
#include <optional>
#include <string_view>
#include <utility>
#include <vector>

namespace util::config {

Array::Array(ConfigValue arg) : itemPattern_{std::move(arg)}
{
}

std::optional<Error>
Array::addValue(Value value, std::optional<std::string_view> key)
{
    auto const& configValPattern = itemPattern_;
    auto const constraint = configValPattern.getConstraint();

    auto newElem = constraint.has_value() ? ConfigValue{configValPattern.type()}.withConstraint(constraint->get())
                                          : ConfigValue{configValPattern.type()};
    if (auto const maybeError = newElem.setValue(value, key); maybeError.has_value())
        return maybeError;
    elements_.emplace_back(std::move(newElem));
    return std::nullopt;
}

size_t
Array::size() const
{
    return elements_.size();
}

ConfigValue const&
Array::at(std::size_t idx) const
{
    ASSERT(idx < elements_.size(), "Index is out of scope");
    return elements_[idx];
}

ConfigValue const&
Array::getArrayPattern() const
{
    return itemPattern_;
}

std::vector<ConfigValue>::const_iterator
Array::begin() const
{
    return elements_.begin();
}

std::vector<ConfigValue>::const_iterator
Array::end() const
{
    return elements_.end();
}

}  // namespace util::config
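The design here is that an Array no longer stores arbitrary pre-built elements; it keeps a single item pattern, and every value added through `addValue` must satisfy that pattern's type and constraint. A self-contained toy version of that idea, using plain standard-library types rather than Clio's ConfigValue API, might look like this:

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <vector>

// Toy illustration only: one "pattern" (here a validation callback) is fixed at
// construction, and addValue appends an element only if the pattern accepts it.
class PatternedArray {
public:
    using Pattern = std::function<std::optional<std::string>(long long)>;

    explicit PatternedArray(Pattern pattern) : pattern_(std::move(pattern)) {}

    std::optional<std::string>
    addValue(long long value)
    {
        if (auto error = pattern_(value); error.has_value())
            return error;  // rejected: the array stays unchanged
        elements_.push_back(value);
        return std::nullopt;
    }

    std::size_t size() const { return elements_.size(); }

private:
    Pattern pattern_;
    std::vector<long long> elements_;
};

int main()
{
    // Every element must look like a valid port, mirroring the bounds used by PortConstraint.
    PatternedArray ports{[](long long p) -> std::optional<std::string> {
        if (p < 1 || p > 65535)
            return "Port does not satisfy the constraint bounds";
        return std::nullopt;
    }};

    std::cout << ports.addValue(8080).value_or("ok") << '\n';   // ok
    std::cout << ports.addValue(70000).value_or("ok") << '\n';  // Port does not satisfy the constraint bounds
    std::cout << ports.size() << '\n';                          // 1
}
```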
@@ -19,47 +19,42 @@

#pragma once

#include "util/Assert.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/ObjectView.hpp"
#include "util/newconfig/ValueView.hpp"
#include "util/newconfig/Error.hpp"
#include "util/newconfig/Types.hpp"

#include <cstddef>
#include <iterator>
#include <type_traits>
#include <utility>
#include <optional>
#include <string_view>
#include <vector>

namespace util::config {

/**
 * @brief Array definition for Json/Yaml config
 * @brief Array definition to store multiple values provided by the user from Json/Yaml
 *
 * Used in ClioConfigDefinition to represent multiple potential values (like whitelist).
 * It is constructed with a single element that states which type/constraint every element
 * in the array must satisfy.
 */
class Array {
public:
    /**
     * @brief Constructs an Array with the provided arguments
     * @brief Constructs an Array with the provided arg
     *
     * @tparam Args Types of the arguments
     * @param args Arguments to initialize the elements of the Array
     * @param arg Argument that sets the type and constraint of the ConfigValues in the Array
     */
    template <typename... Args>
    constexpr Array(Args&&... args) : elements_{std::forward<Args>(args)...}
    {
    }
    Array(ConfigValue arg);

    /**
     * @brief Add ConfigValues to the Array
     *
     * @param value The ConfigValue to add
     * @param key Optional string key to include in the error message
     * @return An optional Error if adding the config value to the array fails, nullopt otherwise
     */
    void
    emplaceBack(ConfigValue value)
    {
        elements_.push_back(std::move(value));
    }
    std::optional<Error>
    addValue(Value value, std::optional<std::string_view> key = std::nullopt);

    /**
     * @brief Returns the number of values stored in the Array
@@ -67,10 +62,7 @@ public:
     * @return Number of values stored in the Array
     */
    [[nodiscard]] size_t
    size() const
    {
        return elements_.size();
    }
    size() const;

    /**
     * @brief Returns the ConfigValue at the specified index
@@ -79,13 +71,35 @@ public:
     * @return ConfigValue at the specified index
     */
    [[nodiscard]] ConfigValue const&
    at(std::size_t idx) const
    {
        ASSERT(idx < elements_.size(), "index is out of scope");
        return elements_[idx];
    }
    at(std::size_t idx) const;

    /**
     * @brief Returns the ConfigValue that defines the type/constraint every
     * ConfigValue must follow in Array
     *
     * @return The item_pattern
     */
    [[nodiscard]] ConfigValue const&
    getArrayPattern() const;

    /**
     * @brief Returns an iterator to the beginning of the ConfigValue vector.
     *
     * @return A constant iterator to the beginning of the vector.
     */
    [[nodiscard]] std::vector<ConfigValue>::const_iterator
    begin() const;

    /**
     * @brief Returns an iterator to the end of the ConfigValue vector.
     *
     * @return A constant iterator to the end of the vector.
     */
    [[nodiscard]] std::vector<ConfigValue>::const_iterator
    end() const;

private:
    ConfigValue itemPattern_;
    std::vector<ConfigValue> elements_;
};

src/util/newconfig/ConfigConstraints.cpp (new file, 103 lines)
@@ -0,0 +1,103 @@
//------------------------------------------------------------------------------
/*
    This file is part of clio: https://github.com/XRPLF/clio
    Copyright (c) 2024, the clio developers.

    Permission to use, copy, modify, and distribute this software for any
    purpose with or without fee is hereby granted, provided that the above
    copyright notice and this permission notice appear in all copies.

    THE  SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
    WITH  REGARD  TO  THIS  SOFTWARE  INCLUDING  ALL  IMPLIED  WARRANTIES  OF
    MERCHANTABILITY  AND  FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
    ANY  SPECIAL,  DIRECT,  INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
    WHATSOEVER  RESULTING  FROM  LOSS  OF USE, DATA OR PROFITS, WHETHER IN AN
    ACTION  OF  CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
    OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================

#include "util/newconfig/ConfigConstraints.hpp"

#include "util/newconfig/Error.hpp"
#include "util/newconfig/Types.hpp"

#include <cstdint>
#include <optional>
#include <regex>
#include <stdexcept>
#include <string>
#include <variant>

namespace util::config {

std::optional<Error>
PortConstraint::checkTypeImpl(Value const& port) const
{
    if (!(std::holds_alternative<int64_t>(port) || std::holds_alternative<std::string>(port)))
        return Error{"Port must be a string or integer"};
    return std::nullopt;
}

std::optional<Error>
PortConstraint::checkValueImpl(Value const& port) const
{
    uint32_t p = 0;
    if (std::holds_alternative<std::string>(port)) {
        try {
            p = static_cast<uint32_t>(std::stoi(std::get<std::string>(port)));
        } catch (std::invalid_argument const& e) {
            return Error{"Port string must be an integer."};
        }
    } else {
        p = static_cast<uint32_t>(std::get<int64_t>(port));
    }
    if (p >= portMin && p <= portMax)
        return std::nullopt;
    return Error{"Port does not satisfy the constraint bounds"};
}

std::optional<Error>
ValidIPConstraint::checkTypeImpl(Value const& ip) const
{
    if (!std::holds_alternative<std::string>(ip))
        return Error{"Ip value must be a string"};
    return std::nullopt;
}

std::optional<Error>
ValidIPConstraint::checkValueImpl(Value const& ip) const
{
    if (std::get<std::string>(ip) == "localhost")
        return std::nullopt;

    static std::regex const ipv4(
        R"(^((25[0-5]|2[0-4][0-9]|1[0-9]{2}|[1-9]?[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9]{2}|[1-9]?[0-9])$)"
    );

    static std::regex const ip_url(
        R"(^((http|https):\/\/)?((([a-zA-Z0-9-]+\.)+[a-zA-Z]{2,6})|(((25[0-5]|2[0-4][0-9]|1[0-9]{2}|[1-9]?[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9]{2}|[1-9]?[0-9])))(:\d{1,5})?(\/[^\s]*)?$)"
    );
    if (std::regex_match(std::get<std::string>(ip), ipv4) || std::regex_match(std::get<std::string>(ip), ip_url))
        return std::nullopt;

    return Error{"Ip is not a valid ip address"};
}

std::optional<Error>
PositiveDouble::checkTypeImpl(Value const& num) const
{
    if (!(std::holds_alternative<double>(num) || std::holds_alternative<int64_t>(num)))
        return Error{"Double number must be of type int or double"};
    return std::nullopt;
}

std::optional<Error>
PositiveDouble::checkValueImpl(Value const& num) const
{
    if (std::get<double>(num) >= 0)
        return std::nullopt;
    return Error{"Double number must be greater than 0"};
}

}  // namespace util::config
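As a quick sanity check of the patterns above, the snippet below runs the IPv4 regex against a few sample inputs; note that "localhost" is accepted by the early return in `checkValueImpl`, not by either regex.

```cpp
#include <iostream>
#include <regex>
#include <string>

int main()
{
    // Same IPv4 pattern as in ValidIPConstraint::checkValueImpl above.
    static std::regex const ipv4(
        R"(^((25[0-5]|2[0-4][0-9]|1[0-9]{2}|[1-9]?[0-9])\.){3}(25[0-5]|2[0-4][0-9]|1[0-9]{2}|[1-9]?[0-9])$)"
    );

    for (std::string const ip : {"127.0.0.1", "10.0.0.256", "example.com"}) {
        // 127.0.0.1 matches; 10.0.0.256 has an out-of-range octet;
        // example.com would only be accepted via the ip_url pattern.
        std::cout << ip << " -> " << std::boolalpha << std::regex_match(ip, ipv4) << '\n';
    }
}
```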
							
								
								
									
src/util/newconfig/ConfigConstraints.hpp (new file, 362 lines)

@@ -0,0 +1,362 @@
 | 
			
		||||
//------------------------------------------------------------------------------
 | 
			
		||||
/*
 | 
			
		||||
   This file is part of clio: https://github.com/XRPLF/clio
 | 
			
		||||
   Copyright (c) 2024, the clio developers.
 | 
			
		||||
 | 
			
		||||
   Permission to use, copy, modify, and distribute this software for any
 | 
			
		||||
   purpose with or without fee is hereby granted, provided that the above
 | 
			
		||||
   copyright notice and this permission notice appear in all copies.
 | 
			
		||||
 | 
			
		||||
   THE  SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 | 
			
		||||
   WITH  REGARD  TO  THIS  SOFTWARE  INCLUDING  ALL  IMPLIED  WARRANTIES  OF
 | 
			
		||||
   MERCHANTABILITY  AND  FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 | 
			
		||||
   ANY  SPECIAL,  DIRECT,  INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 | 
			
		||||
   WHATSOEVER  RESULTING  FROM  LOSS  OF USE, DATA OR PROFITS, WHETHER IN AN
 | 
			
		||||
   ACTION  OF  CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 | 
			
		||||
   OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 | 
			
		||||
*/
 | 
			
		||||
//==============================================================================
 | 
			
		||||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include "rpc/common/APIVersion.hpp"
 | 
			
		||||
#include "util/log/Logger.hpp"
 | 
			
		||||
#include "util/newconfig/Error.hpp"
 | 
			
		||||
#include "util/newconfig/Types.hpp"
 | 
			
		||||
 | 
			
		||||
#include <fmt/core.h>
 | 
			
		||||
#include <fmt/format.h>
 | 
			
		||||
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <array>
 | 
			
		||||
#include <cstddef>
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <limits>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <string_view>
 | 
			
		||||
#include <variant>
 | 
			
		||||
 | 
			
		||||
namespace util::config {
 | 
			
		||||
class ValueView;
 | 
			
		||||
class ConfigValue;
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief specific values that are accepted for logger levels in config.
 | 
			
		||||
 */
 | 
			
		||||
static constexpr std::array<char const*, 7> LOG_LEVELS = {
 | 
			
		||||
    "trace",
 | 
			
		||||
    "debug",
 | 
			
		||||
    "info",
 | 
			
		||||
    "warning",
 | 
			
		||||
    "error",
 | 
			
		||||
    "fatal",
 | 
			
		||||
    "count",
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief specific values that are accepted for logger tag style in config.
 | 
			
		||||
 */
 | 
			
		||||
static constexpr std::array<char const*, 5> LOG_TAGS = {
 | 
			
		||||
    "int",
 | 
			
		||||
    "uint",
 | 
			
		||||
    "null",
 | 
			
		||||
    "none",
 | 
			
		||||
    "uuid",
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief specific values that are accepted for cache loading in config.
 | 
			
		||||
 */
 | 
			
		||||
static constexpr std::array<char const*, 3> LOAD_CACHE_MODE = {
 | 
			
		||||
    "sync",
 | 
			
		||||
    "async",
 | 
			
		||||
    "none",
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief specific values that are accepted for database type in config.
 | 
			
		||||
 */
 | 
			
		||||
static constexpr std::array<char const*, 1> DATABASE_TYPE = {"cassandra"};
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief An interface to enforce constraints on certain values within ClioConfigDefinition.
 | 
			
		||||
 */
 | 
			
		||||
class Constraint {
 | 
			
		||||
public:
 | 
			
		||||
    constexpr virtual ~Constraint() noexcept = default;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Check if the value meets the specific constraint.
 | 
			
		||||
     *
 | 
			
		||||
     * @param val The value to be checked
 | 
			
		||||
     * @return An Error object if the constraint is not met, nullopt otherwise
 | 
			
		||||
     */
 | 
			
		||||
    [[nodiscard]]
 | 
			
		||||
    std::optional<Error>
 | 
			
		||||
    checkConstraint(Value const& val) const
 | 
			
		||||
    {
 | 
			
		||||
        if (auto const maybeError = checkTypeImpl(val); maybeError.has_value())
 | 
			
		||||
            return maybeError;
 | 
			
		||||
        return checkValueImpl(val);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
protected:
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Creates an error message for all constraints that must satisfy certain hard-coded values.
 | 
			
		||||
     *
 | 
			
		||||
     * @tparam arrSize, the size of the array of hardcoded values
 | 
			
		||||
     * @param key The key to the value
 | 
			
		||||
     * @param value The value the user provided
 | 
			
		||||
     * @param arr The array with hard-coded values to add to error message
 | 
			
		||||
     * @return The error message specifying what the value of key must be
 | 
			
		||||
     */
 | 
			
		||||
    template <std::size_t arrSize>
 | 
			
		||||
    constexpr std::string
 | 
			
		||||
    makeErrorMsg(std::string_view key, Value const& value, std::array<char const*, arrSize> arr) const
 | 
			
		||||
    {
 | 
			
		||||
        // Extract the value from the variant
 | 
			
		||||
        auto const valueStr = std::visit([](auto const& v) { return fmt::format("{}", v); }, value);
 | 
			
		||||
 | 
			
		||||
        // Create the error message
 | 
			
		||||
        return fmt::format(
 | 
			
		||||
            R"(You provided value "{}". Key "{}"'s value must be one of the following: {})",
 | 
			
		||||
            valueStr,
 | 
			
		||||
            key,
 | 
			
		||||
            fmt::join(arr, ", ")
 | 
			
		||||
        );
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Check if the value is of a correct type for the constraint.
 | 
			
		||||
     *
 | 
			
		||||
     * @param val The value type to be checked
 | 
			
		||||
     * @return An Error object if the constraint is not met, nullopt otherwise
 | 
			
		||||
     */
 | 
			
		||||
    virtual std::optional<Error>
 | 
			
		||||
    checkTypeImpl(Value const& val) const = 0;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Check if the value is within the constraint.
 | 
			
		||||
     *
 | 
			
		||||
     * @param val The value type to be checked
 | 
			
		||||
     * @return An Error object if the constraint is not met, nullopt otherwise
 | 
			
		||||
     */
 | 
			
		||||
    virtual std::optional<Error>
 | 
			
		||||
    checkValueImpl(Value const& val) const = 0;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
/**
 * @brief A constraint to ensure the port number is within a valid range.
 */
class PortConstraint final : public Constraint {
public:
    constexpr ~PortConstraint() noexcept override = default;

private:
    /**
     * @brief Check if the type of the value is correct for this specific constraint.
     *
     * @param port The type to be checked
     * @return An Error object if the constraint is not met, nullopt otherwise
     */
    [[nodiscard]] std::optional<Error>
    checkTypeImpl(Value const& port) const override;

    /**
     * @brief Check if the value is within the constraint.
     *
     * @param port The value to be checked
     * @return An Error object if the constraint is not met, nullopt otherwise
     */
    [[nodiscard]] std::optional<Error>
    checkValueImpl(Value const& port) const override;

    static constexpr uint32_t portMin = 1;
    static constexpr uint32_t portMax = 65535;
};

/**
 * @brief A constraint to ensure the IP address is valid.
 */
class ValidIPConstraint final : public Constraint {
public:
    constexpr ~ValidIPConstraint() noexcept override = default;

private:
    /**
     * @brief Check if the type of the value is correct for this specific constraint.
     *
     * @param ip The type to be checked.
     * @return An optional Error if the constraint is not met, std::nullopt otherwise
     */
    [[nodiscard]] std::optional<Error>
    checkTypeImpl(Value const& ip) const override;

    /**
     * @brief Check if the value is within the constraint.
     *
     * @param ip The value to be checked
     * @return An Error object if the constraint is not met, nullopt otherwise
     */
    [[nodiscard]] std::optional<Error>
    checkValueImpl(Value const& ip) const override;
};

/**
 * @brief A constraint class to ensure the provided value is one of the specified values in an array.
 *
 * @tparam arrSize The size of the array containing the valid values for the constraint
 */
template <std::size_t arrSize>
class OneOf final : public Constraint {
public:
    /**
     * @brief Constructs a constraint where the value must be one of the values in the provided array.
     *
     * @param key The key of the ConfigValue that has this constraint
     * @param arr The array of valid values for this constraint
     */
    constexpr OneOf(std::string_view key, std::array<char const*, arrSize> arr) : key_{key}, arr_{arr}
    {
    }

    constexpr ~OneOf() noexcept override = default;

private:
    /**
     * @brief Check if the type of the value is correct for this specific constraint.
     *
     * @param val The type to be checked
     * @return An Error object if the constraint is not met, nullopt otherwise
     */
    [[nodiscard]] std::optional<Error>
    checkTypeImpl(Value const& val) const override
    {
        if (!std::holds_alternative<std::string>(val))
            return Error{fmt::format(R"(Key "{}"'s value must be a string)", key_)};
        return std::nullopt;
    }

    /**
     * @brief Check if the value matches one of the values in the provided array
     *
     * @param val The value to check
     * @return An Error object if the constraint is not met, nullopt otherwise
     */
    [[nodiscard]] std::optional<Error>
    checkValueImpl(Value const& val) const override
    {
        namespace rg = std::ranges;
        auto const check = [&val](std::string_view name) { return std::get<std::string>(val) == name; };
        if (rg::any_of(arr_, check))
            return std::nullopt;

        return Error{makeErrorMsg(key_, val, arr_)};
    }

    std::string_view key_;
    std::array<char const*, arrSize> arr_;
};

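For a feel of the message makeErrorMsg produces, here is a small, hypothetical instantiation (editor's sketch; the key "example.mode" and the three values are invented):

#include "util/newconfig/ConfigConstraints.hpp"

#include <array>

using namespace util::config;

// Editor's sketch: a OneOf over three invented values.
static constexpr std::array<char const*, 3> kModes{"sync", "async", "none"};
static constinit OneOf validateMode{"example.mode", kModes};

// validateMode.checkConstraint(Value{std::string{"lazy"}}) then yields an Error whose message reads:
//   You provided value "lazy". Key "example.mode"'s value must be one of the following: sync, async, none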
/**
 * @brief A constraint class to ensure an integer value is between two numbers (inclusive)
 */
template <typename numType>
class NumberValueConstraint final : public Constraint {
public:
    /**
     * @brief Constructs a constraint where the number must be between min_ and max_.
     *
     * @param min the minimum number it can be to satisfy this constraint
     * @param max the maximum number it can be to satisfy this constraint
     */
    constexpr NumberValueConstraint(numType min, numType max) : min_{min}, max_{max}
    {
    }

    constexpr ~NumberValueConstraint() noexcept override = default;

private:
    /**
     * @brief Check if the type of the value is correct for this specific constraint.
     *
     * @param num The type to be checked
     * @return An Error object if the constraint is not met, nullopt otherwise
     */
    [[nodiscard]] std::optional<Error>
    checkTypeImpl(Value const& num) const override
    {
        if (!std::holds_alternative<int64_t>(num))
            return Error{"Number must be of type integer"};
        return std::nullopt;
    }

    /**
     * @brief Check if the number is between min_ and max_ (inclusive).
     *
     * @param num The number to check
     * @return An Error object if the constraint is not met, nullopt otherwise
     */
    [[nodiscard]] std::optional<Error>
    checkValueImpl(Value const& num) const override
    {
        auto const numValue = std::get<int64_t>(num);
        if (numValue >= static_cast<int64_t>(min_) && numValue <= static_cast<int64_t>(max_))
            return std::nullopt;
        return Error{fmt::format("Number must be between {} and {}", min_, max_)};
    }

    numType min_;
    numType max_;
};

/**
 * @brief A constraint to ensure a double number is positive
 */
class PositiveDouble final : public Constraint {
public:
    constexpr ~PositiveDouble() noexcept override = default;

private:
    /**
     * @brief Check if the type of the value is correct for this specific constraint.
     *
     * @param num The type to be checked
     * @return An Error object if the constraint is not met, nullopt otherwise
     */
    [[nodiscard]] std::optional<Error>
    checkTypeImpl(Value const& num) const override;

    /**
     * @brief Check if the number is positive.
     *
     * @param num The number to check
     * @return An Error object if the constraint is not met, nullopt otherwise
     */
    [[nodiscard]] std::optional<Error>
    checkValueImpl(Value const& num) const override;
};

static constinit PortConstraint validatePort{};
static constinit ValidIPConstraint validateIP{};

static constinit OneOf validateChannelName{"channel", Logger::CHANNELS};
static constinit OneOf validateLogLevelName{"log_level", LOG_LEVELS};
static constinit OneOf validateCassandraName{"database.type", DATABASE_TYPE};
static constinit OneOf validateLoadMode{"cache.load", LOAD_CACHE_MODE};
static constinit OneOf validateLogTag{"log_tag_style", LOG_TAGS};

static constinit PositiveDouble validatePositiveDouble{};

static constinit NumberValueConstraint<uint16_t> validateUint16{
    std::numeric_limits<uint16_t>::min(),
    std::numeric_limits<uint16_t>::max()
};
static constinit NumberValueConstraint<uint32_t> validateUint32{
    std::numeric_limits<uint32_t>::min(),
    std::numeric_limits<uint32_t>::max()
};
static constinit NumberValueConstraint<uint32_t> validateApiVersion{rpc::API_VERSION_MIN, rpc::API_VERSION_MAX};

}  // namespace util::config
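These constinit instances are shared by reference from the config definition. A hedged usage sketch (editor's addition; assumes the clio headers above and that Error::error holds the human-readable message):

#include "util/newconfig/ConfigConstraints.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"

#include <iostream>

using namespace util::config;

// Editor's sketch: exercising a constraint directly and through a ConfigValue.
void
demoPortConstraint()
{
    // 70000 is outside [portMin, portMax], so checkConstraint returns an Error.
    if (auto const err = validatePort.checkConstraint(Value{static_cast<int64_t>(70'000)}); err.has_value())
        std::cout << err->error << '\n';

    // The same constraint attached to a ConfigValue, mirroring the .withConstraint(...) calls in ConfigDefinition.cpp below.
    [[maybe_unused]] auto port = ConfigValue{ConfigType::Integer}.defaultValue(5432).withConstraint(validatePort);
}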
@@ -20,10 +20,15 @@
 | 
			
		||||
#include "util/newconfig/ConfigDefinition.hpp"
 | 
			
		||||
 | 
			
		||||
#include "util/Assert.hpp"
 | 
			
		||||
#include "util/OverloadSet.hpp"
 | 
			
		||||
#include "util/newconfig/Array.hpp"
 | 
			
		||||
#include "util/newconfig/ArrayView.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigConstraints.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigFileInterface.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigValue.hpp"
 | 
			
		||||
#include "util/newconfig/Error.hpp"
 | 
			
		||||
#include "util/newconfig/ObjectView.hpp"
 | 
			
		||||
#include "util/newconfig/Types.hpp"
 | 
			
		||||
#include "util/newconfig/ValueView.hpp"
 | 
			
		||||
 | 
			
		||||
#include <fmt/core.h>
 | 
			
		||||
@@ -38,6 +43,7 @@
 | 
			
		||||
#include <thread>
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <variant>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
namespace util::config {
 | 
			
		||||
/**
 | 
			
		||||
@@ -47,62 +53,76 @@ namespace util::config {
 | 
			
		||||
 * without default values must be present in the user's config file.
 | 
			
		||||
 */
 | 
			
		||||
static ClioConfigDefinition ClioConfig = ClioConfigDefinition{
 | 
			
		||||
    {{"database.type", ConfigValue{ConfigType::String}.defaultValue("cassandra")},
 | 
			
		||||
    {{"database.type", ConfigValue{ConfigType::String}.defaultValue("cassandra").withConstraint(validateCassandraName)},
 | 
			
		||||
     {"database.cassandra.contact_points", ConfigValue{ConfigType::String}.defaultValue("localhost")},
 | 
			
		||||
     {"database.cassandra.port", ConfigValue{ConfigType::Integer}},
 | 
			
		||||
     {"database.cassandra.port", ConfigValue{ConfigType::Integer}.withConstraint(validatePort)},
 | 
			
		||||
     {"database.cassandra.keyspace", ConfigValue{ConfigType::String}.defaultValue("clio")},
 | 
			
		||||
     {"database.cassandra.replication_factor", ConfigValue{ConfigType::Integer}.defaultValue(3u)},
 | 
			
		||||
     {"database.cassandra.table_prefix", ConfigValue{ConfigType::String}.defaultValue("table_prefix")},
 | 
			
		||||
     {"database.cassandra.max_write_requests_outstanding", ConfigValue{ConfigType::Integer}.defaultValue(10'000)},
 | 
			
		||||
     {"database.cassandra.max_read_requests_outstanding", ConfigValue{ConfigType::Integer}.defaultValue(100'000)},
 | 
			
		||||
     {"database.cassandra.max_write_requests_outstanding",
 | 
			
		||||
      ConfigValue{ConfigType::Integer}.defaultValue(10'000).withConstraint(validateUint32)},
 | 
			
		||||
     {"database.cassandra.max_read_requests_outstanding",
 | 
			
		||||
      ConfigValue{ConfigType::Integer}.defaultValue(100'000).withConstraint(validateUint32)},
 | 
			
		||||
     {"database.cassandra.threads",
 | 
			
		||||
      ConfigValue{ConfigType::Integer}.defaultValue(static_cast<uint32_t>(std::thread::hardware_concurrency()))},
 | 
			
		||||
     {"database.cassandra.core_connections_per_host", ConfigValue{ConfigType::Integer}.defaultValue(1)},
 | 
			
		||||
     {"database.cassandra.queue_size_io", ConfigValue{ConfigType::Integer}.optional()},
 | 
			
		||||
     {"database.cassandra.write_batch_size", ConfigValue{ConfigType::Integer}.defaultValue(20)},
 | 
			
		||||
     {"etl_source.[].ip", Array{ConfigValue{ConfigType::String}.optional()}},
 | 
			
		||||
     {"etl_source.[].ws_port", Array{ConfigValue{ConfigType::String}.optional().min(1).max(65535)}},
 | 
			
		||||
     {"etl_source.[].grpc_port", Array{ConfigValue{ConfigType::String}.optional().min(1).max(65535)}},
 | 
			
		||||
     {"forwarding.cache_timeout", ConfigValue{ConfigType::Double}.defaultValue(0.0)},
 | 
			
		||||
     {"forwarding.request_timeout", ConfigValue{ConfigType::Double}.defaultValue(10.0)},
 | 
			
		||||
      ConfigValue{ConfigType::Integer}
 | 
			
		||||
          .defaultValue(static_cast<uint32_t>(std::thread::hardware_concurrency()))
 | 
			
		||||
          .withConstraint(validateUint32)},
 | 
			
		||||
     {"database.cassandra.core_connections_per_host",
 | 
			
		||||
      ConfigValue{ConfigType::Integer}.defaultValue(1).withConstraint(validateUint16)},
 | 
			
		||||
     {"database.cassandra.queue_size_io", ConfigValue{ConfigType::Integer}.optional().withConstraint(validateUint16)},
 | 
			
		||||
     {"database.cassandra.write_batch_size",
 | 
			
		||||
      ConfigValue{ConfigType::Integer}.defaultValue(20).withConstraint(validateUint16)},
 | 
			
		||||
     {"etl_source.[].ip", Array{ConfigValue{ConfigType::String}.withConstraint(validateIP)}},
 | 
			
		||||
     {"etl_source.[].ws_port", Array{ConfigValue{ConfigType::String}.withConstraint(validatePort)}},
 | 
			
		||||
     {"etl_source.[].grpc_port", Array{ConfigValue{ConfigType::String}.withConstraint(validatePort)}},
 | 
			
		||||
     {"forwarding.cache_timeout",
 | 
			
		||||
      ConfigValue{ConfigType::Double}.defaultValue(0.0).withConstraint(validatePositiveDouble)},
 | 
			
		||||
     {"forwarding.request_timeout",
 | 
			
		||||
      ConfigValue{ConfigType::Double}.defaultValue(10.0).withConstraint(validatePositiveDouble)},
 | 
			
		||||
     {"dos_guard.whitelist.[]", Array{ConfigValue{ConfigType::String}}},
 | 
			
		||||
     {"dos_guard.max_fetches", ConfigValue{ConfigType::Integer}.defaultValue(1000'000)},
 | 
			
		||||
     {"dos_guard.max_connections", ConfigValue{ConfigType::Integer}.defaultValue(20)},
 | 
			
		||||
     {"dos_guard.max_requests", ConfigValue{ConfigType::Integer}.defaultValue(20)},
 | 
			
		||||
     {"dos_guard.sweep_interval", ConfigValue{ConfigType::Double}.defaultValue(1.0)},
 | 
			
		||||
     {"cache.peers.[].ip", Array{ConfigValue{ConfigType::String}}},
 | 
			
		||||
     {"cache.peers.[].port", Array{ConfigValue{ConfigType::String}}},
 | 
			
		||||
     {"server.ip", ConfigValue{ConfigType::String}},
 | 
			
		||||
     {"server.port", ConfigValue{ConfigType::Integer}},
 | 
			
		||||
     {"server.max_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(0)},
 | 
			
		||||
     {"dos_guard.max_fetches", ConfigValue{ConfigType::Integer}.defaultValue(1000'000).withConstraint(validateUint32)},
 | 
			
		||||
     {"dos_guard.max_connections", ConfigValue{ConfigType::Integer}.defaultValue(20).withConstraint(validateUint32)},
 | 
			
		||||
     {"dos_guard.max_requests", ConfigValue{ConfigType::Integer}.defaultValue(20).withConstraint(validateUint32)},
 | 
			
		||||
     {"dos_guard.sweep_interval",
 | 
			
		||||
      ConfigValue{ConfigType::Double}.defaultValue(1.0).withConstraint(validatePositiveDouble)},
 | 
			
		||||
     {"cache.peers.[].ip", Array{ConfigValue{ConfigType::String}.withConstraint(validateIP)}},
 | 
			
		||||
     {"cache.peers.[].port", Array{ConfigValue{ConfigType::String}.withConstraint(validatePort)}},
 | 
			
		||||
     {"server.ip", ConfigValue{ConfigType::String}.withConstraint(validateIP)},
 | 
			
		||||
     {"server.port", ConfigValue{ConfigType::Integer}.withConstraint(validatePort)},
 | 
			
		||||
     {"server.workers", ConfigValue{ConfigType::Integer}.withConstraint(validateUint32)},
 | 
			
		||||
     {"server.max_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(0).withConstraint(validateUint32)},
 | 
			
		||||
     {"server.local_admin", ConfigValue{ConfigType::Boolean}.optional()},
 | 
			
		||||
     {"server.admin_password", ConfigValue{ConfigType::String}.optional()},
 | 
			
		||||
     {"prometheus.enabled", ConfigValue{ConfigType::Boolean}.defaultValue(true)},
 | 
			
		||||
     {"prometheus.compress_reply", ConfigValue{ConfigType::Boolean}.defaultValue(true)},
 | 
			
		||||
     {"io_threads", ConfigValue{ConfigType::Integer}.defaultValue(2)},
 | 
			
		||||
     {"cache.num_diffs", ConfigValue{ConfigType::Integer}.defaultValue(32)},
 | 
			
		||||
     {"cache.num_markers", ConfigValue{ConfigType::Integer}.defaultValue(48)},
 | 
			
		||||
     {"cache.num_cursors_from_diff", ConfigValue{ConfigType::Integer}.defaultValue(0)},
 | 
			
		||||
     {"cache.num_cursors_from_account", ConfigValue{ConfigType::Integer}.defaultValue(0)},
 | 
			
		||||
     {"cache.page_fetch_size", ConfigValue{ConfigType::Integer}.defaultValue(512)},
 | 
			
		||||
     {"cache.load", ConfigValue{ConfigType::String}.defaultValue("async")},
 | 
			
		||||
     {"log_channels.[].channel", Array{ConfigValue{ConfigType::String}.optional()}},
 | 
			
		||||
     {"log_channels.[].log_level", Array{ConfigValue{ConfigType::String}.optional()}},
 | 
			
		||||
     {"log_level", ConfigValue{ConfigType::String}.defaultValue("info")},
 | 
			
		||||
     {"io_threads", ConfigValue{ConfigType::Integer}.defaultValue(2).withConstraint(validateUint16)},
 | 
			
		||||
     {"cache.num_diffs", ConfigValue{ConfigType::Integer}.defaultValue(32).withConstraint(validateUint16)},
 | 
			
		||||
     {"cache.num_markers", ConfigValue{ConfigType::Integer}.defaultValue(48).withConstraint(validateUint16)},
 | 
			
		||||
     {"cache.num_cursors_from_diff", ConfigValue{ConfigType::Integer}.defaultValue(0).withConstraint(validateUint16)},
 | 
			
		||||
     {"cache.num_cursors_from_account", ConfigValue{ConfigType::Integer}.defaultValue(0).withConstraint(validateUint16)
 | 
			
		||||
     },
 | 
			
		||||
     {"cache.page_fetch_size", ConfigValue{ConfigType::Integer}.defaultValue(512).withConstraint(validateUint16)},
 | 
			
		||||
     {"cache.load", ConfigValue{ConfigType::String}.defaultValue("async").withConstraint(validateLoadMode)},
 | 
			
		||||
     {"log_channels.[].channel", Array{ConfigValue{ConfigType::String}.optional().withConstraint(validateChannelName)}},
 | 
			
		||||
     {"log_channels.[].log_level",
 | 
			
		||||
      Array{ConfigValue{ConfigType::String}.optional().withConstraint(validateLogLevelName)}},
 | 
			
		||||
     {"log_level", ConfigValue{ConfigType::String}.defaultValue("info").withConstraint(validateLogLevelName)},
 | 
			
		||||
     {"log_format",
 | 
			
		||||
      ConfigValue{ConfigType::String}.defaultValue(
 | 
			
		||||
          R"(%TimeStamp% (%SourceLocation%) [%ThreadID%] %Channel%:%Severity% %Message%)"
 | 
			
		||||
      )},
 | 
			
		||||
     {"log_to_console", ConfigValue{ConfigType::Boolean}.defaultValue(false)},
 | 
			
		||||
     {"log_directory", ConfigValue{ConfigType::String}.optional()},
 | 
			
		||||
     {"log_rotation_size", ConfigValue{ConfigType::Integer}.defaultValue(2048)},
 | 
			
		||||
     {"log_directory_max_size", ConfigValue{ConfigType::Integer}.defaultValue(50 * 1024)},
 | 
			
		||||
     {"log_rotation_hour_interval", ConfigValue{ConfigType::Integer}.defaultValue(12)},
 | 
			
		||||
     {"log_tag_style", ConfigValue{ConfigType::String}.defaultValue("uint")},
 | 
			
		||||
     {"extractor_threads", ConfigValue{ConfigType::Integer}.defaultValue(2u)},
 | 
			
		||||
     {"log_rotation_size", ConfigValue{ConfigType::Integer}.defaultValue(2048u).withConstraint(validateUint32)},
 | 
			
		||||
     {"log_directory_max_size",
 | 
			
		||||
      ConfigValue{ConfigType::Integer}.defaultValue(50u * 1024u).withConstraint(validateUint32)},
 | 
			
		||||
     {"log_rotation_hour_interval", ConfigValue{ConfigType::Integer}.defaultValue(12).withConstraint(validateUint32)},
 | 
			
		||||
     {"log_tag_style", ConfigValue{ConfigType::String}.defaultValue("uint").withConstraint(validateLogTag)},
 | 
			
		||||
     {"extractor_threads", ConfigValue{ConfigType::Integer}.defaultValue(2u).withConstraint(validateUint32)},
 | 
			
		||||
     {"read_only", ConfigValue{ConfigType::Boolean}.defaultValue(false)},
 | 
			
		||||
     {"txn_threshold", ConfigValue{ConfigType::Integer}.defaultValue(0)},
 | 
			
		||||
     {"start_sequence", ConfigValue{ConfigType::String}.optional()},
 | 
			
		||||
     {"finish_sequence", ConfigValue{ConfigType::String}.optional()},
 | 
			
		||||
     {"txn_threshold", ConfigValue{ConfigType::Integer}.defaultValue(0).withConstraint(validateUint16)},
 | 
			
		||||
     {"start_sequence", ConfigValue{ConfigType::Integer}.optional().withConstraint(validateUint32)},
 | 
			
		||||
     {"finish_sequence", ConfigValue{ConfigType::Integer}.optional().withConstraint(validateUint32)},
 | 
			
		||||
     {"ssl_cert_file", ConfigValue{ConfigType::String}.optional()},
 | 
			
		||||
     {"ssl_key_file", ConfigValue{ConfigType::String}.optional()},
 | 
			
		||||
     {"api_version.min", ConfigValue{ConfigType::Integer}},
 | 
			
		||||
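For orientation, a user config touching a few of the keys defined above could look like the fragment below (editor's sketch only; values are illustrative and the fragment is not claimed to be a complete, valid configuration). Nested objects map to the dotted keys in the definition, and arrays map to the "[]" entries:

// Editor's sketch: a small user config matching some of the keys defined above.
static constexpr char const* kExampleConfig = R"JSON(
{
    "database": {
        "type": "cassandra",
        "cassandra": { "contact_points": "localhost", "port": 9042 }
    },
    "etl_source": [
        { "ip": "127.0.0.1", "ws_port": "6006", "grpc_port": "50051" }
    ],
    "server": { "ip": "0.0.0.0", "port": 51233 },
    "log_level": "info",
    "cache": { "load": "async" }
}
)JSON";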
@@ -113,7 +133,7 @@ ClioConfigDefinition::ClioConfigDefinition(std::initializer_list<KeyValuePair> p
 | 
			
		||||
{
 | 
			
		||||
    for (auto const& [key, value] : pair) {
 | 
			
		||||
        if (key.contains("[]"))
 | 
			
		||||
            ASSERT(std::holds_alternative<Array>(value), "Value must be array if key has \"[]\"");
 | 
			
		||||
            ASSERT(std::holds_alternative<Array>(value), R"(Value must be array if key has "[]")");
 | 
			
		||||
        map_.insert({key, value});
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
@@ -206,4 +226,51 @@ ClioConfigDefinition::arraySize(std::string_view prefix) const
 | 
			
		||||
    std::unreachable();
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::optional<std::vector<Error>>
ClioConfigDefinition::parse(ConfigFileInterface const& config)
{
    std::vector<Error> listOfErrors;
    for (auto& [key, value] : map_) {
        // If the key doesn't exist in the user config, make sure it is marked as ".optional()" or has
        // ".defaultValue()" in the ClioConfigDefinition above
        if (!config.containsKey(key)) {
            if (std::holds_alternative<ConfigValue>(value)) {
                if (!(std::get<ConfigValue>(value).isOptional() || std::get<ConfigValue>(value).hasValue()))
                    listOfErrors.emplace_back(key, "key is required in user Config");
            } else if (std::holds_alternative<Array>(value)) {
                if (!(std::get<Array>(value).getArrayPattern().isOptional()))
                    listOfErrors.emplace_back(key, "key is required in user Config");
            }
            continue;
        }
        ASSERT(
            std::holds_alternative<ConfigValue>(value) || std::holds_alternative<Array>(value),
            "Value must be of type ConfigValue or Array"
        );
        std::visit(
            util::OverloadSet{// handle the case where the config value is a single element.
                              // attempt to set the value from the configuration for the specified key.
                              [&key, &config, &listOfErrors](ConfigValue& val) {
                                  if (auto const maybeError = val.setValue(config.getValue(key), key);
                                      maybeError.has_value())
                                      listOfErrors.emplace_back(maybeError.value());
                              },
                              // handle the case where the config value is an array.
                              // iterate over each provided value in the array and attempt to set it for the key.
                              [&key, &config, &listOfErrors](Array& arr) {
                                  for (auto const& val : config.getArray(key)) {
                                      if (auto const maybeError = arr.addValue(val, key); maybeError.has_value())
                                          listOfErrors.emplace_back(maybeError.value());
                                  }
                              }
            },
            value
        );
    }
    if (!listOfErrors.empty())
        return listOfErrors;

    return std::nullopt;
}

}  // namespace util::config

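A hedged end-to-end sketch of the new parse flow (editor's addition; the file name "config.json", the function name, and direct access to Error::error are assumptions):

#include "util/newconfig/ConfigDefinition.hpp"
#include "util/newconfig/ConfigFileJson.hpp"

#include <cstdlib>
#include <iostream>

using namespace util::config;

// Editor's sketch: load a user config file and report every validation failure at once.
int
loadConfig(ClioConfigDefinition& definition)
{
    auto expectedJson = ConfigFileJson::make_ConfigFileJson("config.json");
    if (!expectedJson.has_value()) {
        std::cerr << expectedJson.error().error << '\n';
        return EXIT_FAILURE;
    }

    if (auto const errors = definition.parse(expectedJson.value()); errors.has_value()) {
        for (auto const& err : *errors)
            std::cerr << err.error << '\n';
        return EXIT_FAILURE;
    }
    return EXIT_SUCCESS;
}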
@@ -24,7 +24,7 @@
 | 
			
		||||
#include "util/newconfig/ConfigDescription.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigFileInterface.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigValue.hpp"
 | 
			
		||||
#include "util/newconfig/Errors.hpp"
 | 
			
		||||
#include "util/newconfig/Error.hpp"
 | 
			
		||||
#include "util/newconfig/ObjectView.hpp"
 | 
			
		||||
#include "util/newconfig/ValueView.hpp"
 | 
			
		||||
 | 
			
		||||
@@ -41,6 +41,7 @@
 | 
			
		||||
#include <unordered_map>
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <variant>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
namespace util::config {
 | 
			
		||||
 | 
			
		||||
@@ -66,12 +67,13 @@ public:
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Parses the configuration file
 | 
			
		||||
     *
 | 
			
		||||
     * Should also check that no extra configuration key/value pairs are present
 | 
			
		||||
     * Also checks that no extra configuration key/value pairs are present. Adds to the list of Errors
     * if any are found
 | 
			
		||||
     *
 | 
			
		||||
     * @param config The configuration file interface
 | 
			
		||||
     * @return An optional Error object if parsing fails
 | 
			
		||||
     * @return An optional vector of Error objects stating all the failures if parsing fails
 | 
			
		||||
     */
 | 
			
		||||
    [[nodiscard]] std::optional<Error>
 | 
			
		||||
    [[nodiscard]] std::optional<std::vector<Error>>
 | 
			
		||||
    parse(ConfigFileInterface const& config);
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
@@ -80,9 +82,9 @@ public:
 | 
			
		||||
     * Should only check for valid values, without populating
 | 
			
		||||
     *
 | 
			
		||||
     * @param config The configuration file interface
 | 
			
		||||
     * @return An optional Error object if validation fails
 | 
			
		||||
     * @return An optional vector of Error objects stating all the failures if validation fails
 | 
			
		||||
     */
 | 
			
		||||
    [[nodiscard]] std::optional<Error>
 | 
			
		||||
    [[nodiscard]] std::optional<std::vector<Error>>
 | 
			
		||||
    validate(ConfigFileInterface const& config) const;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
 
 | 
			
		||||
@@ -90,7 +90,9 @@ private:
 | 
			
		||||
        KV{"server.ip", "IP address of the Clio HTTP server."},
 | 
			
		||||
        KV{"server.port", "Port number of the Clio HTTP server."},
 | 
			
		||||
        KV{"server.max_queue_size", "Maximum size of the server's request queue."},
 | 
			
		||||
        KV{"server.workers", "Maximum number of threads for server to run with."},
 | 
			
		||||
        KV{"server.local_admin", "Indicates if the server should run with admin privileges."},
 | 
			
		||||
        KV{"server.admin_password", "Password for Clio admin-only APIs."},
 | 
			
		||||
        KV{"prometheus.enabled", "Enable or disable Prometheus metrics."},
 | 
			
		||||
        KV{"prometheus.compress_reply", "Enable or disable compression of Prometheus responses."},
 | 
			
		||||
        KV{"io_threads", "Number of I/O threads."},
 | 
			
		||||
 
 | 
			
		||||
@@ -19,9 +19,8 @@

#pragma once

#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"

#include <optional>
#include <string_view>
#include <vector>

@@ -36,31 +35,33 @@ namespace util::config {
class ConfigFileInterface {
public:
    virtual ~ConfigFileInterface() = default;
    /**
     * @brief Parses the provided path of user clio configuration data
     *
     * @param filePath The path to the Clio Config data
     */
    virtual void
    parse(std::string_view filePath) = 0;

    /**
     * @brief Retrieves a configuration value.
     * @brief Retrieves the value of configValue.
     *
     * @param key The key of the configuration value.
     * @return An optional containing the configuration value if found, otherwise std::nullopt.
     * @param key The key of the configuration.
     * @return The value associated with the key.
     */
    virtual std::optional<ConfigValue>
    virtual Value
    getValue(std::string_view key) const = 0;

    /**
     * @brief Retrieves an array of configuration values.
     *
     * @param key The key of the configuration array.
     * @return An optional containing a vector of configuration values if found, otherwise std::nullopt.
     * @return A vector of configuration values associated with the key.
     */
    virtual std::optional<std::vector<ConfigValue>>
    virtual std::vector<Value>
    getArray(std::string_view key) const = 0;

    /**
     * @brief Checks if the key exists in the configuration file.
     *
     * @param key The key to search for.
     * @return true if key exists in configuration file, false otherwise.
     */
    virtual bool
    containsKey(std::string_view key) const = 0;
};

}  // namespace util::config

src/util/newconfig/ConfigFileJson.cpp (new file, 166 lines)
@@ -0,0 +1,166 @@
 | 
			
		||||
//------------------------------------------------------------------------------
 | 
			
		||||
/*
 | 
			
		||||
    This file is part of clio: https://github.com/XRPLF/clio
 | 
			
		||||
    Copyright (c) 2024, the clio developers.
 | 
			
		||||
 | 
			
		||||
    Permission to use, copy, modify, and distribute this software for any
 | 
			
		||||
    purpose with or without fee is hereby granted, provided that the above
 | 
			
		||||
    copyright notice and this permission notice appear in all copies.
 | 
			
		||||
 | 
			
		||||
    THE  SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 | 
			
		||||
    WITH  REGARD  TO  THIS  SOFTWARE  INCLUDING  ALL  IMPLIED  WARRANTIES  OF
 | 
			
		||||
    MERCHANTABILITY  AND  FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 | 
			
		||||
    ANY  SPECIAL,  DIRECT,  INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 | 
			
		||||
    WHATSOEVER  RESULTING  FROM  LOSS  OF USE, DATA OR PROFITS, WHETHER IN AN
 | 
			
		||||
    ACTION  OF  CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 | 
			
		||||
    OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 | 
			
		||||
*/
 | 
			
		||||
//==============================================================================
 | 
			
		||||
 | 
			
		||||
#include "util/newconfig/ConfigFileJson.hpp"
 | 
			
		||||
 | 
			
		||||
#include "util/Assert.hpp"
 | 
			
		||||
#include "util/newconfig/Error.hpp"
 | 
			
		||||
#include "util/newconfig/Types.hpp"
 | 
			
		||||
 | 
			
		||||
#include <boost/filesystem/path.hpp>
 | 
			
		||||
#include <boost/json/array.hpp>
 | 
			
		||||
#include <boost/json/object.hpp>
 | 
			
		||||
#include <boost/json/parse.hpp>
 | 
			
		||||
#include <boost/json/parse_options.hpp>
 | 
			
		||||
#include <boost/json/value.hpp>
 | 
			
		||||
#include <fmt/core.h>
 | 
			
		||||
 | 
			
		||||
#include <cstddef>
 | 
			
		||||
#include <exception>
 | 
			
		||||
#include <fstream>
 | 
			
		||||
#include <ios>
 | 
			
		||||
#include <iostream>
 | 
			
		||||
#include <ostream>
 | 
			
		||||
#include <sstream>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <string_view>
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
namespace util::config {
 | 
			
		||||
 | 
			
		||||
namespace {
 | 
			
		||||
/**
 | 
			
		||||
 * @brief Extracts the value from a JSON object and converts it into the corresponding type.
 | 
			
		||||
 *
 | 
			
		||||
 * @param jsonValue The JSON value to extract.
 | 
			
		||||
 * @return A variant containing the same type corresponding to the extracted value.
 | 
			
		||||
 */
 | 
			
		||||
[[nodiscard]] Value
 | 
			
		||||
extractJsonValue(boost::json::value const& jsonValue)
 | 
			
		||||
{
 | 
			
		||||
    if (jsonValue.is_int64()) {
 | 
			
		||||
        return jsonValue.as_int64();
 | 
			
		||||
    }
 | 
			
		||||
    if (jsonValue.is_string()) {
 | 
			
		||||
        return jsonValue.as_string().c_str();
 | 
			
		||||
    }
 | 
			
		||||
    if (jsonValue.is_bool()) {
 | 
			
		||||
        return jsonValue.as_bool();
 | 
			
		||||
    }
 | 
			
		||||
    if (jsonValue.is_double()) {
 | 
			
		||||
        return jsonValue.as_double();
 | 
			
		||||
    }
 | 
			
		||||
    ASSERT(false, "Json is not of type int, string, bool or double");
 | 
			
		||||
    std::unreachable();
 | 
			
		||||
}
 | 
			
		||||
}  // namespace
 | 
			
		||||
 | 
			
		||||
ConfigFileJson::ConfigFileJson(boost::json::object jsonObj)
 | 
			
		||||
{
 | 
			
		||||
    flattenJson(jsonObj, "");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::expected<ConfigFileJson, Error>
 | 
			
		||||
ConfigFileJson::make_ConfigFileJson(boost::filesystem::path configFilePath)
 | 
			
		||||
{
 | 
			
		||||
    try {
 | 
			
		||||
        if (auto const in = std::ifstream(configFilePath.string(), std::ios::in | std::ios::binary); in) {
 | 
			
		||||
            std::stringstream contents;
 | 
			
		||||
            contents << in.rdbuf();
 | 
			
		||||
            auto opts = boost::json::parse_options{};
 | 
			
		||||
            opts.allow_comments = true;
 | 
			
		||||
            auto const tempObj = boost::json::parse(contents.str(), {}, opts).as_object();
 | 
			
		||||
            return ConfigFileJson{tempObj};
 | 
			
		||||
        }
 | 
			
		||||
        return std::unexpected<Error>(
 | 
			
		||||
            Error{fmt::format("Could not open configuration file '{}'", configFilePath.string())}
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
    } catch (std::exception const& e) {
 | 
			
		||||
        return std::unexpected<Error>(Error{fmt::format(
 | 
			
		||||
            "An error occurred while processing configuration file '{}': {}", configFilePath.string(), e.what()
 | 
			
		||||
        )});
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
Value
 | 
			
		||||
ConfigFileJson::getValue(std::string_view key) const
 | 
			
		||||
{
 | 
			
		||||
    auto const jsonValue = jsonObject_.at(key);
 | 
			
		||||
    auto const value = extractJsonValue(jsonValue);
 | 
			
		||||
    return value;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
std::vector<Value>
 | 
			
		||||
ConfigFileJson::getArray(std::string_view key) const
 | 
			
		||||
{
 | 
			
		||||
    ASSERT(jsonObject_.at(key).is_array(), "Key {} has value that is not an array", key);
 | 
			
		||||
 | 
			
		||||
    std::vector<Value> configValues;
 | 
			
		||||
    auto const arr = jsonObject_.at(key).as_array();
 | 
			
		||||
 | 
			
		||||
    for (auto const& item : arr) {
 | 
			
		||||
        auto const value = extractJsonValue(item);
 | 
			
		||||
        configValues.emplace_back(value);
 | 
			
		||||
    }
 | 
			
		||||
    return configValues;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
bool
 | 
			
		||||
ConfigFileJson::containsKey(std::string_view key) const
 | 
			
		||||
{
 | 
			
		||||
    return jsonObject_.contains(key);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
void
ConfigFileJson::flattenJson(boost::json::object const& obj, std::string const& prefix)
{
    for (auto const& [key, value] : obj) {
        std::string const fullKey = prefix.empty() ? std::string(key) : fmt::format("{}.{}", prefix, std::string(key));

        // In ClioConfigDefinition, value must be a primitive or array
        if (value.is_object()) {
            flattenJson(value.as_object(), fullKey);
        } else if (value.is_array()) {
            auto const& arr = value.as_array();
            for (std::size_t i = 0; i < arr.size(); ++i) {
                std::string const arrayPrefix = fullKey + ".[]";
                if (arr[i].is_object()) {
                    flattenJson(arr[i].as_object(), arrayPrefix);
                } else {
                    jsonObject_[arrayPrefix] = arr;
                }
            }
        } else {
            // if "[]" is present in key, then value must be an array instead of primitive
            if (fullKey.contains(".[]") && !jsonObject_.contains(fullKey)) {
                boost::json::array newArray;
                newArray.emplace_back(value);
                jsonObject_[fullKey] = newArray;
            } else if (fullKey.contains(".[]") && jsonObject_.contains(fullKey)) {
                jsonObject_[fullKey].as_array().emplace_back(value);
            } else {
                jsonObject_[fullKey] = value;
            }
        }
    }
}
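To make the flattening concrete, the sketch below (editor's addition; the nested object is invented, key shapes follow the config definition earlier in this changeset) shows which keys the constructor ends up storing:

#include "util/newconfig/ConfigFileJson.hpp"

#include <boost/json/parse.hpp>

#include <cassert>

using namespace util::config;

// Editor's sketch: nested objects become dotted keys, arrays of objects become "<key>.[].<field>".
void
demoFlattening()
{
    auto const obj = boost::json::parse(R"({
        "server": { "ip": "127.0.0.1", "port": 51233 },
        "etl_source": [ { "ip": "10.0.0.1" }, { "ip": "10.0.0.2" } ]
    })").as_object();

    ConfigFileJson const json{obj};
    assert(json.containsKey("server.ip"));
    assert(json.containsKey("server.port"));
    assert(json.getArray("etl_source.[].ip").size() == 2);
}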
 | 
			
		||||
}  // namespace util::config

src/util/newconfig/ConfigFileJson.hpp (new file, 100 lines)
@@ -0,0 +1,100 @@
 | 
			
		||||
//------------------------------------------------------------------------------
 | 
			
		||||
/*
 | 
			
		||||
    This file is part of clio: https://github.com/XRPLF/clio
 | 
			
		||||
    Copyright (c) 2024, the clio developers.
 | 
			
		||||
 | 
			
		||||
    Permission to use, copy, modify, and distribute this software for any
 | 
			
		||||
    purpose with or without fee is hereby granted, provided that the above
 | 
			
		||||
    copyright notice and this permission notice appear in all copies.
 | 
			
		||||
 | 
			
		||||
    THE  SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 | 
			
		||||
    WITH  REGARD  TO  THIS  SOFTWARE  INCLUDING  ALL  IMPLIED  WARRANTIES  OF
 | 
			
		||||
    MERCHANTABILITY  AND  FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 | 
			
		||||
    ANY  SPECIAL,  DIRECT,  INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 | 
			
		||||
    WHATSOEVER  RESULTING  FROM  LOSS  OF USE, DATA OR PROFITS, WHETHER IN AN
 | 
			
		||||
    ACTION  OF  CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 | 
			
		||||
    OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 | 
			
		||||
*/
 | 
			
		||||
//==============================================================================
 | 
			
		||||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include "util/newconfig/ConfigFileInterface.hpp"
 | 
			
		||||
#include "util/newconfig/Error.hpp"
 | 
			
		||||
#include "util/newconfig/Types.hpp"
 | 
			
		||||
 | 
			
		||||
#include <boost/filesystem/path.hpp>
 | 
			
		||||
#include <boost/json/object.hpp>
 | 
			
		||||
 | 
			
		||||
#include <expected>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <string_view>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
namespace util::config {
 | 
			
		||||
 | 
			
		||||
/** @brief Json representation of config */
 | 
			
		||||
class ConfigFileJson final : public ConfigFileInterface {
 | 
			
		||||
public:
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Constructs a new ConfigFileJson object and stores the values from the
     * user's config in a JSON object.
 | 
			
		||||
     *
 | 
			
		||||
     * @param jsonObj the Json object to parse; represents user's config
 | 
			
		||||
     */
 | 
			
		||||
    ConfigFileJson(boost::json::object jsonObj);
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Retrieves a configuration value by its key.
 | 
			
		||||
     *
 | 
			
		||||
     * @param key The key of the configuration value to retrieve.
 | 
			
		||||
     * @return A variant containing the same type corresponding to the extracted value.
 | 
			
		||||
     */
 | 
			
		||||
    [[nodiscard]] Value
 | 
			
		||||
    getValue(std::string_view key) const override;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Retrieves an array of configuration values by its key.
 | 
			
		||||
     *
 | 
			
		||||
     * @param key The key of the configuration array to retrieve.
 | 
			
		||||
     * @return A vector of variants holding the config values specified by user.
 | 
			
		||||
     */
 | 
			
		||||
    [[nodiscard]] std::vector<Value>
 | 
			
		||||
    getArray(std::string_view key) const override;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Checks if the configuration contains a specific key.
 | 
			
		||||
     *
 | 
			
		||||
     * @param key The key to check for.
 | 
			
		||||
     * @return True if the key exists, false otherwise.
 | 
			
		||||
     */
 | 
			
		||||
    [[nodiscard]] bool
 | 
			
		||||
    containsKey(std::string_view key) const override;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Creates a new ConfigFileJson by parsing the provided JSON file and
 | 
			
		||||
     * stores the values in the object.
 | 
			
		||||
     *
 | 
			
		||||
     * @param configFilePath The path to the JSON file to be parsed.
 | 
			
		||||
     * @return A ConfigFileJson object if parsing the user's file succeeds, an Error otherwise
 | 
			
		||||
     */
 | 
			
		||||
    [[nodiscard]] static std::expected<ConfigFileJson, Error>
 | 
			
		||||
    make_ConfigFileJson(boost::filesystem::path configFilePath);
 | 
			
		||||
 | 
			
		||||
private:
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Recursive function to flatten a JSON object into the same structure as the Clio Config.
 | 
			
		||||
     *
 | 
			
		||||
     * The keys will end up following the same naming conventions as in the Clio Config.
 | 
			
		||||
     * Other than the keys specified in user Config file, no new keys are created.
 | 
			
		||||
     *
 | 
			
		||||
     * @param obj The JSON object to flatten.
 | 
			
		||||
     * @param prefix The prefix to use for the keys in the flattened object.
 | 
			
		||||
     */
 | 
			
		||||
    void
 | 
			
		||||
    flattenJson(boost::json::object const& obj, std::string const& prefix);
 | 
			
		||||
 | 
			
		||||
    boost::json::object jsonObject_;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // namespace util::config
 | 
			
		||||
@@ -19,15 +19,31 @@
 | 
			
		||||
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <string>
 | 
			
		||||
#include "util/newconfig/ConfigFileInterface.hpp"
 | 
			
		||||
#include "util/newconfig/Types.hpp"
 | 
			
		||||
 | 
			
		||||
#include <boost/filesystem/path.hpp>
 | 
			
		||||
 | 
			
		||||
#include <string_view>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
// TODO: implement when we support yaml
 | 
			
		||||
 | 
			
		||||
namespace util::config {
 | 
			
		||||
 | 
			
		||||
/** @brief todo: Will display the different errors when parsing config */
 | 
			
		||||
struct Error {
 | 
			
		||||
    std::string_view key;
 | 
			
		||||
    std::string_view error;
 | 
			
		||||
/** @brief Yaml representation of config */
 | 
			
		||||
class ConfigFileYaml final : public ConfigFileInterface {
 | 
			
		||||
public:
 | 
			
		||||
    ConfigFileYaml() = default;
 | 
			
		||||
 | 
			
		||||
    Value
 | 
			
		||||
    getValue(std::string_view key) const override;
 | 
			
		||||
 | 
			
		||||
    std::vector<Value>
 | 
			
		||||
    getArray(std::string_view key) const override;
 | 
			
		||||
 | 
			
		||||
    bool
 | 
			
		||||
    containsKey(std::string_view key) const override;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // namespace util::config
 | 
			
		||||
@@ -20,43 +20,23 @@
 | 
			
		||||
#pragma once
 | 
			
		||||
 | 
			
		||||
#include "util/Assert.hpp"
 | 
			
		||||
#include "util/UnsupportedType.hpp"
 | 
			
		||||
#include "util/OverloadSet.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigConstraints.hpp"
 | 
			
		||||
#include "util/newconfig/Error.hpp"
 | 
			
		||||
#include "util/newconfig/Types.hpp"
 | 
			
		||||
 | 
			
		||||
#include <fmt/core.h>
 | 
			
		||||
 | 
			
		||||
#include <cstddef>
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <functional>
 | 
			
		||||
#include <optional>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <type_traits>
 | 
			
		||||
#include <string_view>
 | 
			
		||||
#include <variant>
 | 
			
		||||
 | 
			
		||||
namespace util::config {
 | 
			
		||||
 | 
			
		||||
/** @brief Custom clio config types */
 | 
			
		||||
enum class ConfigType { Integer, String, Double, Boolean };
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief Get the corresponding clio config type
 | 
			
		||||
 *
 | 
			
		||||
 * @tparam Type The type to get the corresponding ConfigType for
 | 
			
		||||
 * @return The corresponding ConfigType
 | 
			
		||||
 */
 | 
			
		||||
template <typename Type>
 | 
			
		||||
constexpr ConfigType
 | 
			
		||||
getType()
 | 
			
		||||
{
 | 
			
		||||
    if constexpr (std::is_same_v<Type, int64_t>) {
 | 
			
		||||
        return ConfigType::Integer;
 | 
			
		||||
    } else if constexpr (std::is_same_v<Type, std::string>) {
 | 
			
		||||
        return ConfigType::String;
 | 
			
		||||
    } else if constexpr (std::is_same_v<Type, double>) {
 | 
			
		||||
        return ConfigType::Double;
 | 
			
		||||
    } else if constexpr (std::is_same_v<Type, bool>) {
 | 
			
		||||
        return ConfigType::Boolean;
 | 
			
		||||
    } else {
 | 
			
		||||
        static_assert(util::Unsupported<Type>, "Wrong config type");
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
 * @brief Represents the config values for Json/Yaml config
 | 
			
		||||
 *
 | 
			
		||||
@@ -65,8 +45,6 @@ getType()
 | 
			
		||||
 */
 | 
			
		||||
class ConfigValue {
 | 
			
		||||
public:
 | 
			
		||||
    using Type = std::variant<int64_t, std::string, bool, double>;
 | 
			
		||||
 | 
			
		||||
    /**
 | 
			
		||||
     * @brief Constructor initializing with the config type
 | 
			
		||||
     *
 | 
			
		||||
@@ -83,12 +61,92 @@ public:
 | 
			
		||||
     * @return Reference to this ConfigValue
 | 
			
		||||
     */
 | 
			
		||||
    [[nodiscard]] ConfigValue&
 | 
			
		||||
    defaultValue(Type value)
 | 
			
		||||
    defaultValue(Value value)
 | 
			
		||||
    {
 | 
			
		||||
        setValue(value);
 | 
			
		||||
        auto const err = checkTypeConsistency(type_, value);
 | 
			
		||||
        ASSERT(!err.has_value(), "{}", err->error);
 | 
			
		||||
        value_ = value;
 | 
			
		||||
        return *this;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /**
     * @brief Sets the value of the current ConfigValue to the user-provided value
     *
     * @param value The value to set
     * @param key The config key associated with the value. Optional; used in the error message shown to the user.
     * @return An optional Error if the value has the wrong type or does not satisfy the constraint
     */
    [[nodiscard]] std::optional<Error>
    setValue(Value value, std::optional<std::string_view> key = std::nullopt)
    {
        auto err = checkTypeConsistency(type_, value);
        if (err.has_value()) {
            if (key.has_value())
                err->error = fmt::format("{} {}", key.value(), err->error);
            return err;
        }

        if (cons_.has_value()) {
            auto constraintCheck = cons_->get().checkConstraint(value);
            if (constraintCheck.has_value()) {
                if (key.has_value())
                    constraintCheck->error = fmt::format("{} {}", key.value(), constraintCheck->error);
                return constraintCheck;
            }
        }
        value_ = value;
        return std::nullopt;
    }

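A hedged sketch of the resulting behaviour (editor's addition; the exact message wording is an assumption): both the type check and the constraint check surface as a returned Error rather than an assertion:

#include "util/newconfig/ConfigConstraints.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"

using namespace util::config;

// Editor's sketch: both failure paths of setValue() come back as an Error value.
void
demoSetValue()
{
    auto port = ConfigValue{ConfigType::Integer}.withConstraint(validatePort);

    // Wrong type: the message is prefixed with the key, e.g. "server.port value does not match type integer".
    [[maybe_unused]] auto const typeErr = port.setValue(Value{std::string{"not a number"}}, "server.port");

    // Right type but outside the PortConstraint range; the exact wording comes from checkValueImpl.
    [[maybe_unused]] auto const rangeErr = port.setValue(Value{static_cast<int64_t>(70'000)}, "server.port");
}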
    /**
 | 
			
		||||
     * @brief Assigns a constraint to the ConfigValue.
 | 
			
		||||
     *
 | 
			
		||||
     * This method associates a specific constraint with the ConfigValue.
 | 
			
		||||
     * If the ConfigValue already holds a value, the method will check whether
 | 
			
		||||
     * the value satisfies the given constraint. If the constraint is not satisfied,
 | 
			
		||||
     * an assertion failure will occur with a detailed error message.
 | 
			
		||||
     *
 | 
			
		||||
     * @param cons The constraint to be applied to the ConfigValue.
 | 
			
		||||
     * @return A reference to the modified ConfigValue object.
 | 
			
		||||
     */
 | 
			
		||||
    [[nodiscard]] constexpr ConfigValue&
 | 
			
		||||
    withConstraint(Constraint const& cons)
 | 
			
		||||
    {
 | 
			
		||||
        cons_ = std::reference_wrapper<Constraint const>(cons);
 | 
			
		||||
        ASSERT(cons_.has_value(), "Constraint must be defined");
 | 
			
		||||
 | 
			
		||||
        if (value_.has_value()) {
 | 
			
		||||
            auto const& temp = cons_.value().get();
 | 
			
		||||
            auto const& result = temp.checkConstraint(value_.value());
 | 
			
		||||
            if (result.has_value()) {
 | 
			
		||||
                // useful for specifying clear Error message
 | 
			
		||||
                std::string type;
 | 
			
		||||
                std::visit(
 | 
			
		||||
                    util::OverloadSet{
 | 
			
		||||
                        [&type](bool tmp) { type = fmt::format("bool {}", tmp); },
 | 
			
		||||
                        [&type](std::string const& tmp) { type = fmt::format("string {}", tmp); },
                        [&type](double tmp) { type = fmt::format("double {}", tmp); },
                        [&type](int64_t tmp) { type = fmt::format("int {}", tmp); }
                    },
                    value_.value()
                );
                ASSERT(false, "Value {} ConfigValue does not satisfy the set Constraint", type);
            }
        }
        return *this;
    }

    /**
     * @brief Retrieves the constraint associated with this ConfigValue, if any.
     *
     * @return An optional reference to the associated Constraint.
     */
    [[nodiscard]] std::optional<std::reference_wrapper<Constraint const>>
    getConstraint() const
    {
        return cons_;
    }

    /**
     * @brief Gets the config type
     *
@@ -100,32 +158,6 @@ public:
        return type_;
    }

    /**
     * @brief Sets the minimum value for the config
     *
     * @param min The minimum value
     * @return Reference to this ConfigValue
     */
    [[nodiscard]] constexpr ConfigValue&
    min(std::uint32_t min)
    {
        min_ = min;
        return *this;
    }

    /**
     * @brief Sets the maximum value for the config
     *
     * @param max The maximum value
     * @return Reference to this ConfigValue
     */
    [[nodiscard]] constexpr ConfigValue&
    max(std::uint32_t max)
    {
        max_ = max;
        return *this;
    }

    /**
     * @brief Sets the config value as optional, meaning the user doesn't have to provide the value in their config
     *
@@ -165,7 +197,7 @@ public:
     *
     * @return Config Value
     */
    [[nodiscard]] Type const&
    [[nodiscard]] Value const&
    getValue() const
    {
        return value_.value();
@@ -178,39 +210,28 @@ private:
     * @param type The config type
     * @param value The config value
     */
    static void
    checkTypeConsistency(ConfigType type, Type value)
    static std::optional<Error>
    checkTypeConsistency(ConfigType type, Value value)
    {
        if (std::holds_alternative<std::string>(value)) {
            ASSERT(type == ConfigType::String, "Value does not match type string");
        } else if (std::holds_alternative<bool>(value)) {
            ASSERT(type == ConfigType::Boolean, "Value does not match type boolean");
        } else if (std::holds_alternative<double>(value)) {
            ASSERT(type == ConfigType::Double, "Value does not match type double");
        } else if (std::holds_alternative<int64_t>(value)) {
            ASSERT(type == ConfigType::Integer, "Value does not match type integer");
        if (type == ConfigType::String && !std::holds_alternative<std::string>(value)) {
            return Error{"value does not match type string"};
        }
    }

    /**
     * @brief Sets the value for the config
     *
     * @param value The value to set
     * @return The value that was set
     */
    Type
    setValue(Type value)
    {
        checkTypeConsistency(type_, value);
        value_ = value;
        return value;
        if (type == ConfigType::Boolean && !std::holds_alternative<bool>(value)) {
            return Error{"value does not match type boolean"};
        }
        if (type == ConfigType::Double && !std::holds_alternative<double>(value)) {
            return Error{"value does not match type double"};
        }
        if (type == ConfigType::Integer && !std::holds_alternative<int64_t>(value)) {
            return Error{"value does not match type integer"};
        }
        return std::nullopt;
    }

    ConfigType type_{};
    bool optional_{false};
    std::optional<Type> value_;
    std::optional<std::uint32_t> min_;
    std::optional<std::uint32_t> max_;
    std::optional<Value> value_;
    std::optional<std::reference_wrapper<Constraint const>> cons_;
};

}  // namespace util::config

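Illustrative note on the change above (not part of the commits): the type check now reports a mismatch as an optional error instead of asserting. A minimal, self-contained sketch of the same idea, using its own local Value and ConfigType definitions rather than the project's headers:

#include <cstdint>
#include <optional>
#include <string>
#include <variant>

using Value = std::variant<int64_t, std::string, bool, double>;
enum class ConfigType { Integer, String, Double, Boolean };

// Returns an error message when the stored alternative does not match the
// declared config type; std::nullopt means the value is consistent.
std::optional<std::string>
checkType(ConfigType type, Value const& value)
{
    if (type == ConfigType::Integer && !std::holds_alternative<int64_t>(value))
        return "value does not match type integer";
    if (type == ConfigType::String && !std::holds_alternative<std::string>(value))
        return "value does not match type string";
    if (type == ConfigType::Double && !std::holds_alternative<double>(value))
        return "value does not match type double";
    if (type == ConfigType::Boolean && !std::holds_alternative<bool>(value))
        return "value does not match type boolean";
    return std::nullopt;  // types are consistent
}
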
 57  src/util/newconfig/Error.hpp  Normal file
@@ -0,0 +1,57 @@
//------------------------------------------------------------------------------
/*
    This file is part of clio: https://github.com/XRPLF/clio
    Copyright (c) 2024, the clio developers.

    Permission to use, copy, modify, and distribute this software for any
    purpose with or without fee is hereby granted, provided that the above
    copyright notice and this permission notice appear in all copies.

    THE  SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
    WITH  REGARD  TO  THIS  SOFTWARE  INCLUDING  ALL  IMPLIED  WARRANTIES  OF
    MERCHANTABILITY  AND  FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
    ANY  SPECIAL,  DIRECT,  INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
    WHATSOEVER  RESULTING  FROM  LOSS  OF USE, DATA OR PROFITS, WHETHER IN AN
    ACTION  OF  CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
    OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================

#pragma once

#include <fmt/core.h>

#include <string>
#include <string_view>
#include <utility>

namespace util::config {

/** @brief Displays the different errors when parsing user config */
struct Error {
    /**
     * @brief Constructs an Error with a custom error message.
     *
     * @param err the error message to display to users.
     */
    Error(std::string err) : error{std::move(err)}
    {
    }

    /**
     * @brief Constructs an Error with a custom error message.
     *
     * @param key the key associated with the error.
     * @param err the error message to display to users.
     */
    Error(std::string_view key, std::string_view err)
        : error{
              fmt::format("{} {}", key, err),
          }
    {
    }

    std::string error;
};

}  // namespace util::config

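Hypothetical usage of this new Error type (not taken from the commits, and assuming Error.hpp is included): parse errors can be accumulated and reported together rather than aborting on the first one.

#include <iostream>
#include <vector>

void
reportConfigErrors()
{
    std::vector<util::config::Error> errors;

    // Two-argument form prefixes the message with the offending key
    errors.emplace_back("dosguard.port", "value does not match type integer");
    // Single-argument form carries a free-form message
    errors.emplace_back("unrecognized key 'idk'");

    for (auto const& err : errors)
        std::cerr << err.error << '\n';
}
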
 60  src/util/newconfig/Types.hpp  Normal file
@@ -0,0 +1,60 @@
//------------------------------------------------------------------------------
/*
   This file is part of clio: https://github.com/XRPLF/clio
   Copyright (c) 2024, the clio developers.

   Permission to use, copy, modify, and distribute this software for any
   purpose with or without fee is hereby granted, provided that the above
   copyright notice and this permission notice appear in all copies.

   THE  SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
   WITH  REGARD  TO  THIS  SOFTWARE  INCLUDING  ALL  IMPLIED  WARRANTIES  OF
   MERCHANTABILITY  AND  FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
   ANY  SPECIAL,  DIRECT,  INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
   WHATSOEVER  RESULTING  FROM  LOSS  OF USE, DATA OR PROFITS, WHETHER IN AN
   ACTION  OF  CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
   OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================

#pragma once

#include "util/UnsupportedType.hpp"

#include <cstdint>
#include <string>
#include <type_traits>
#include <variant>

namespace util::config {

/** @brief Custom clio config types */
enum class ConfigType { Integer, String, Double, Boolean };

/** @brief Represents the supported Config Values */
using Value = std::variant<int64_t, std::string, bool, double>;

/**
 * @brief Get the corresponding clio config type
 *
 * @tparam Type The type to get the corresponding ConfigType for
 * @return The corresponding ConfigType
 */
template <typename Type>
constexpr ConfigType
getType()
{
    if constexpr (std::is_same_v<Type, int64_t>) {
        return ConfigType::Integer;
    } else if constexpr (std::is_same_v<Type, std::string>) {
        return ConfigType::String;
    } else if constexpr (std::is_same_v<Type, double>) {
        return ConfigType::Double;
    } else if constexpr (std::is_same_v<Type, bool>) {
        return ConfigType::Boolean;
    } else {
        static_assert(util::Unsupported<Type>, "Wrong config type");
    }
}

}  // namespace util::config

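Because getType() is constexpr, the mapping from C++ type to ConfigType can be verified at compile time. A small illustrative check (not part of the commits, assuming Types.hpp is included):

#include <cstdint>
#include <string>

// Each supported alternative of Value maps to exactly one ConfigType
static_assert(util::config::getType<int64_t>() == util::config::ConfigType::Integer);
static_assert(util::config::getType<std::string>() == util::config::ConfigType::String);
static_assert(util::config::getType<double>() == util::config::ConfigType::Double);
static_assert(util::config::getType<bool>() == util::config::ConfigType::Boolean);
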
@@ -21,6 +21,7 @@

#include "util/Assert.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"

#include <cstdint>
#include <string>
@@ -55,9 +56,9 @@ double
ValueView::asDouble() const
{
    if (configVal_.get().hasValue()) {
        if (type() == ConfigType::Double) {
        if (type() == ConfigType::Double)
            return std::get<double>(configVal_.get().getValue());
        }

        if (type() == ConfigType::Integer)
            return static_cast<double>(std::get<int64_t>(configVal_.get().getValue()));
    }

@@ -21,6 +21,7 @@

#include "util/Assert.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"

#include <fmt/core.h>

@@ -84,7 +85,7 @@ public:
                return static_cast<T>(val);
            }
        }
        ASSERT(false, "Value view is not of any Int type");
        ASSERT(false, "Value view is not of Int type");
        return 0;
    }

@@ -39,8 +39,10 @@
#include <boost/beast/websocket/stream_base.hpp>
#include <boost/system/errc.hpp>

#include <atomic>
#include <chrono>
#include <expected>
#include <memory>
#include <optional>
#include <string>
#include <utility>
@@ -123,15 +125,20 @@ private:
    static void
    withTimeout(Operation&& operation, boost::asio::yield_context yield, std::chrono::steady_clock::duration timeout)
    {
        auto isCompleted = std::make_shared<bool>(false);
        boost::asio::cancellation_signal cancellationSignal;
        auto cyield = boost::asio::bind_cancellation_slot(cancellationSignal.slot(), yield);

        boost::asio::steady_timer timer{boost::asio::get_associated_executor(cyield), timeout};
        timer.async_wait([&cancellationSignal](boost::system::error_code errorCode) {
            if (!errorCode)

        // The timer below can be called with no error code even if the operation is completed before the timeout, so we
        // need an additional flag here
        timer.async_wait([&cancellationSignal, isCompleted](boost::system::error_code errorCode) {
            if (!errorCode and not *isCompleted)
                cancellationSignal.emit(boost::asio::cancellation_type::terminal);
        });
        operation(cyield);
        *isCompleted = true;
    }

    static boost::system::error_code

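A standalone sketch of the same watchdog pattern (illustrative only, not the project's code): a cancellation slot is bound to the coroutine's yield_context, and a steady_timer emits a terminal cancellation unless a completion flag was set first. Here the "operation" is just another timer that deliberately outlives the watchdog.

#include <boost/asio/associated_executor.hpp>
#include <boost/asio/bind_cancellation_slot.hpp>
#include <boost/asio/cancellation_signal.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/asio/steady_timer.hpp>
#include <boost/system/system_error.hpp>

#include <chrono>
#include <iostream>
#include <memory>

int
main()
{
    boost::asio::io_context ioc;

    boost::asio::spawn(ioc, [](boost::asio::yield_context yield) {
        auto isCompleted = std::make_shared<bool>(false);
        boost::asio::cancellation_signal cancellationSignal;
        auto cyield = boost::asio::bind_cancellation_slot(cancellationSignal.slot(), yield);

        // Watchdog: cancel the operation below unless it finished first
        boost::asio::steady_timer watchdog{
            boost::asio::get_associated_executor(cyield), std::chrono::milliseconds(50)
        };
        watchdog.async_wait([&cancellationSignal, isCompleted](boost::system::error_code ec) {
            if (!ec && !*isCompleted)
                cancellationSignal.emit(boost::asio::cancellation_type::terminal);
        });

        // The "operation": a wait that takes longer than the watchdog allows
        boost::asio::steady_timer slowOperation{
            boost::asio::get_associated_executor(cyield), std::chrono::seconds(1)
        };
        try {
            slowOperation.async_wait(cyield);
            *isCompleted = true;
            std::cout << "operation completed before the timeout\n";
        } catch (boost::system::system_error const& e) {
            std::cout << "operation cancelled: " << e.code().message() << '\n';
        }
    });

    ioc.run();
}
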
@@ -43,9 +43,11 @@

#include <algorithm>
#include <expected>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

@@ -75,6 +77,14 @@ TestWsConnection::send(std::string const& message, boost::asio::yield_context yi
    return std::nullopt;
}

void
TestWsConnection::sendPing(boost::beast::websocket::ping_data const& data, boost::asio::yield_context yield)
{
    boost::beast::error_code errorCode;
    ws_.async_ping(data, yield[errorCode]);
    [&]() { ASSERT_FALSE(errorCode) << errorCode.message(); }();
}

std::optional<std::string>
TestWsConnection::receive(boost::asio::yield_context yield)
{
@@ -105,6 +115,20 @@ TestWsConnection::headers() const
    return headers_;
}

void
TestWsConnection::setControlFrameCallback(
    std::function<void(boost::beast::websocket::frame_type, std::string_view)> callback
)
{
    ws_.control_callback(std::move(callback));
}

void
TestWsConnection::resetControlFrameCallback()
{
    ws_.control_callback();
}

TestWsServer::TestWsServer(asio::io_context& context, std::string const& host) : acceptor_(context)
{
    auto endpoint = asio::ip::tcp::endpoint(boost::asio::ip::make_address(host), 0);

@@ -25,6 +25,7 @@
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/beast/core/tcp_stream.hpp>
#include <boost/beast/websocket/rfc6455.hpp>
#include <boost/beast/websocket/stream.hpp>

#include <expected>
@@ -54,6 +55,9 @@ public:
    std::optional<std::string>
    send(std::string const& message, boost::asio::yield_context yield);

    void
    sendPing(boost::beast::websocket::ping_data const& data, boost::asio::yield_context yield);

    // returns nullopt if the connection is closed
    std::optional<std::string>
    receive(boost::asio::yield_context yield);
@@ -63,6 +67,12 @@ public:

    std::vector<util::requests::HttpHeader> const&
    headers() const;

    void
    setControlFrameCallback(std::function<void(boost::beast::websocket::frame_type, std::string_view)> callback);

    void
    resetControlFrameCallback();
};
using TestWsConnectionPtr = std::unique_ptr<TestWsConnection>;

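The new setControlFrameCallback/resetControlFrameCallback helpers above are thin wrappers over Beast's control_callback. A hedged sketch of that underlying mechanism (not the project's code; it assumes a recent Boost where beast::string_view converts to std::string_view, as the diff's own signature does):

#include <boost/beast/core/tcp_stream.hpp>
#include <boost/beast/websocket/stream.hpp>

#include <iostream>
#include <string_view>

void
observeControlFrames(boost::beast::websocket::stream<boost::beast::tcp_stream>& ws)
{
    // Invoked by Beast whenever a ping, pong, or close frame is read
    ws.control_callback([](boost::beast::websocket::frame_type kind, std::string_view payload) {
        if (kind == boost::beast::websocket::frame_type::ping)
            std::cout << "ping received, payload: " << payload << '\n';
        else if (kind == boost::beast::websocket::frame_type::pong)
            std::cout << "pong received\n";
    });

    // Calling control_callback() with no argument removes the callback again.
}
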
@@ -20,13 +20,25 @@
#pragma once

#include "util/newconfig/Array.hpp"
#include "util/newconfig/ConfigConstraints.hpp"
#include "util/newconfig/ConfigDefinition.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"

#include <gtest/gtest.h>

using namespace util::config;

/**
 * @brief A mock ClioConfigDefinition for testing purposes.
 *
 * In the actual Clio configuration, arrays typically hold optional values, meaning users are not required to
 * provide values for them.
 *
 * For primitive types (i.e., single specific values), some are mandatory and must be explicitly defined in the
 * user's configuration file, including both the key and the corresponding value, while some are optional
 */

inline ClioConfigDefinition
generateConfig()
{
@@ -36,60 +48,167 @@ generateConfig()
        {"header.admin", ConfigValue{ConfigType::Boolean}.defaultValue(true)},
        {"header.sub.sub2Value", ConfigValue{ConfigType::String}.defaultValue("TSM")},
        {"ip", ConfigValue{ConfigType::Double}.defaultValue(444.22)},
        {"array.[].sub",
         Array{
             ConfigValue{ConfigType::Double}.defaultValue(111.11), ConfigValue{ConfigType::Double}.defaultValue(4321.55)
         }},
        {"array.[].sub2",
         Array{
             ConfigValue{ConfigType::String}.defaultValue("subCategory"),
             ConfigValue{ConfigType::String}.defaultValue("temporary")
         }},
        {"higher.[].low.section", Array{ConfigValue{ConfigType::String}.defaultValue("true")}},
        {"higher.[].low.admin", Array{ConfigValue{ConfigType::Boolean}.defaultValue(false)}},
        {"dosguard.whitelist.[]",
         Array{
             ConfigValue{ConfigType::String}.defaultValue("125.5.5.2"),
             ConfigValue{ConfigType::String}.defaultValue("204.2.2.2")
         }},
        {"dosguard.port", ConfigValue{ConfigType::Integer}.defaultValue(55555)}
        {"array.[].sub", Array{ConfigValue{ConfigType::Double}}},
        {"array.[].sub2", Array{ConfigValue{ConfigType::String}.optional()}},
        {"higher.[].low.section", Array{ConfigValue{ConfigType::String}.withConstraint(validateChannelName)}},
        {"higher.[].low.admin", Array{ConfigValue{ConfigType::Boolean}}},
        {"dosguard.whitelist.[]", Array{ConfigValue{ConfigType::String}.optional()}},
        {"dosguard.port", ConfigValue{ConfigType::Integer}.defaultValue(55555).withConstraint(validatePort)},
        {"optional.withDefault", ConfigValue{ConfigType::Double}.defaultValue(0.0).optional()},
        {"optional.withNoDefault", ConfigValue{ConfigType::Double}.optional()},
        {"requireValue", ConfigValue{ConfigType::String}}
    };
}

/* The config definition above would look like this structure in config.json:
"header": {
   "text1": "value",
   "port": 123,
   "admin": true,
   "sub": {
       "sub2Value": "TSM"
   }
 },
 "ip": 444.22,
 "array": [
   {
       "sub": 111.11,
       "sub2": "subCategory"
   },
   {
       "sub": 4321.55,
       "sub2": "temporary"
   }
 ],
 "higher": [
   {
       "low": {
           "section": "true",
           "admin": false
/* The config definition above would look like this structure in config.json
{
    "header": {
       "text1": "value",
       "port": 321,
       "admin": true,
       "sub": {
           "sub2Value": "TSM"
       }
   }
 ],
 "dosguard":  {
    "whitelist": [
        "125.5.5.2", "204.2.2.2"
    ],
    "port" : 55555
 },


     },
     "ip": 444.22,
     "array": [
       {
           "sub": //optional for user to include
           "sub2": //optional for user to include
       },
     ],
     "higher": [
       {
           "low": {
               "section": //optional for user to include
               "admin": //optional for user to include
           }
       }
     ],
     "dosguard":  {
        "whitelist": [
            // mandatory for user to include
        ],
        "port" : 55555
        },
    },
    "optional" : {
        "withDefault" : 0.0,
        "withNoDefault" :  //optional for user to include
        },
    "requireValue" : // value must be provided by user
    }
*/

/* Used to test overwriting default values in ClioConfigDefinition Above */
constexpr static auto JSONData = R"JSON(
    {
    "header": {
       "text1": "value",
       "port": 321,
       "admin": false,
       "sub": {
           "sub2Value": "TSM"
       }
     },
     "array": [
       {
           "sub": 111.11,
           "sub2": "subCategory"
       },
       {
           "sub": 4321.55,
           "sub2": "temporary"
       },
       {
           "sub": 5555.44,
           "sub2": "london"
       }
     ],
      "higher": [
       {
           "low": {
               "section": "WebServer",
               "admin": false
           }
       }
     ],
     "dosguard":  {
        "whitelist": [
            "125.5.5.1", "204.2.2.1"
        ],
        "port" : 44444
        },
    "optional" : {
        "withDefault" : 0.0
        },
    "requireValue" : "required"
    }
)JSON";

/* After parsing jsonValue and populating it into ClioConfig, It will look like this below in json format;
{
    "header": {
       "text1": "value",
       "port": 321,
       "admin": false,
       "sub": {
           "sub2Value": "TSM"
       }
     },
     "ip": 444.22,
     "array": [
        {
           "sub": 111.11,
           "sub2": "subCategory"
       },
       {
           "sub": 4321.55,
           "sub2": "temporary"
       },
       {
           "sub": 5555.44,
           "sub2": "london"
       }
     ],
     "higher": [
       {
           "low": {
               "section": "WebServer",
               "admin": false
           }
       }
     ],
     "dosguard":  {
        "whitelist": [
            "125.5.5.1", "204.2.2.1"
        ],
        "port" : 44444
        }
    },
    "optional" : {
        "withDefault" : 0.0
        },
    "requireValue" : "required"
    }
*/

// Invalid Json key/values
constexpr static auto invalidJSONData = R"JSON(
{
    "header": {
        "port": "999",
        "admin": "true"
    },
    "dosguard": {
        "whitelist": [
            false
        ]
    },
    "idk": true,
    "requireValue" : "required",
    "optional" : {
        "withDefault" : "0.0"
        }
}
)JSON";

@@ -133,12 +133,13 @@ target_sources(
          web/RPCServerHandlerTests.cpp
          web/ServerTests.cpp
          # New Config
          util/newconfig/ArrayViewTests.cpp
          util/newconfig/ObjectViewTests.cpp
          util/newconfig/ValueViewTests.cpp
          util/newconfig/ArrayTests.cpp
          util/newconfig/ConfigValueTests.cpp
          util/newconfig/ArrayViewTests.cpp
          util/newconfig/ClioConfigDefinitionTests.cpp
          util/newconfig/ConfigValueTests.cpp
          util/newconfig/ObjectViewTests.cpp
          util/newconfig/JsonConfigFileTests.cpp
          util/newconfig/ValueViewTests.cpp
)

configure_file(test_data/cert.pem ${CMAKE_BINARY_DIR}/tests/unit/test_data/cert.pem COPYONLY)

@@ -49,7 +49,7 @@ struct AmendmentCenterTest : util::prometheus::WithPrometheus, MockBackendTest,
TEST_F(AmendmentCenterTest, AllAmendmentsFromLibXRPLAreSupported)
{
    for (auto const& [name, _] : ripple::allAmendments()) {
        ASSERT_TRUE(amendmentCenter.isSupported(name)) << "XRPL amendment not supported by Clio: " << name;
        EXPECT_TRUE(amendmentCenter.isSupported(name)) << "XRPL amendment not supported by Clio: " << name;
    }

    ASSERT_EQ(amendmentCenter.getSupported().size(), ripple::allAmendments().size());

@@ -280,15 +280,12 @@ TEST_F(LoadBalancerOnDisconnectHookTests, source0Disconnects)
    EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(false));
    EXPECT_CALL(sourceFactory_.sourceAt(1), isConnected()).WillOnce(Return(true));
    EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(true));
    sourceFactory_.callbacksAt(0).onDisconnect();
    sourceFactory_.callbacksAt(0).onDisconnect(true);
}

TEST_F(LoadBalancerOnDisconnectHookTests, source1Disconnects)
{
    EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).WillOnce(Return(true));
    EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(true));
    EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(false));
    sourceFactory_.callbacksAt(1).onDisconnect();
    sourceFactory_.callbacksAt(1).onDisconnect(false);
}

TEST_F(LoadBalancerOnDisconnectHookTests, source0DisconnectsAndConnectsBack)
@@ -297,29 +294,25 @@ TEST_F(LoadBalancerOnDisconnectHookTests, source0DisconnectsAndConnectsBack)
    EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(false));
    EXPECT_CALL(sourceFactory_.sourceAt(1), isConnected()).WillOnce(Return(true));
    EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(true));
    sourceFactory_.callbacksAt(0).onDisconnect();
    sourceFactory_.callbacksAt(0).onDisconnect(true);

    sourceFactory_.callbacksAt(0).onConnect();
}

TEST_F(LoadBalancerOnDisconnectHookTests, source1DisconnectsAndConnectsBack)
{
    EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).WillOnce(Return(true));
    EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(true));
    EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(false));
    sourceFactory_.callbacksAt(1).onDisconnect();

    sourceFactory_.callbacksAt(1).onDisconnect(false);
    sourceFactory_.callbacksAt(1).onConnect();
}

TEST_F(LoadBalancerOnConnectHookTests, bothSourcesDisconnectAndConnectBack)
{
    EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).Times(2).WillRepeatedly(Return(false));
    EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(false)).Times(2);
    EXPECT_CALL(sourceFactory_.sourceAt(1), isConnected()).Times(2).WillRepeatedly(Return(false));
    EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(false)).Times(2);
    sourceFactory_.callbacksAt(0).onDisconnect();
    sourceFactory_.callbacksAt(1).onDisconnect();
    EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).WillOnce(Return(false));
    EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(false));
    EXPECT_CALL(sourceFactory_.sourceAt(1), isConnected()).WillOnce(Return(false));
    EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(false));
    sourceFactory_.callbacksAt(0).onDisconnect(true);
    sourceFactory_.callbacksAt(1).onDisconnect(false);

    EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).WillOnce(Return(true));
    EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(true));
@@ -362,12 +355,7 @@ TEST_F(LoadBalancer3SourcesTests, forwardingUpdate)
    sourceFactory_.callbacksAt(1).onConnect();

    // Source 0 got disconnected
    EXPECT_CALL(sourceFactory_.sourceAt(0), isConnected()).WillOnce(Return(false));
    EXPECT_CALL(sourceFactory_.sourceAt(0), setForwarding(false));
    EXPECT_CALL(sourceFactory_.sourceAt(1), isConnected()).WillOnce(Return(true));
    EXPECT_CALL(sourceFactory_.sourceAt(1), setForwarding(true));
    EXPECT_CALL(sourceFactory_.sourceAt(2), setForwarding(false));  // only source 1 must be forwarding
    sourceFactory_.callbacksAt(0).onDisconnect();
    sourceFactory_.callbacksAt(0).onDisconnect(false);
}

struct LoadBalancerLoadInitialLedgerTests : LoadBalancerOnConnectHookTests {

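The tests above now pass a boolean (whether the source was forwarding) into the onDisconnect hook. With gmock, adding a parameter to a mocked hook lets each expectation pin the exact argument value. A small illustrative test of that mechanism, not taken from the diff:

#include <gmock/gmock.h>
#include <gtest/gtest.h>

TEST(MockFunctionBoolArgExample, MatchesExactArgument)
{
    // StrictMock fails the test on any call that has no matching expectation
    testing::StrictMock<testing::MockFunction<void(bool)>> hook;

    EXPECT_CALL(hook, Call(true)).Times(1);
    EXPECT_CALL(hook, Call(false)).Times(1);

    auto callback = hook.AsStdFunction();
    callback(true);
    callback(false);
}
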
@@ -20,27 +20,33 @@
#include "etl/impl/SubscriptionSource.hpp"
#include "util/LoggerFixtures.hpp"
#include "util/MockNetworkValidatedLedgers.hpp"
#include "util/MockPrometheus.hpp"
#include "util/MockSubscriptionManager.hpp"
#include "util/TestWsServer.hpp"
#include "util/prometheus/Gauge.hpp"

#include <boost/asio/io_context.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/json/object.hpp>
#include <boost/json/serialize.hpp>
#include <fmt/core.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <chrono>
#include <cstdint>
#include <cstdlib>
#include <optional>
#include <string>
#include <thread>
#include <utility>

using namespace etl::impl;
using testing::MockFunction;
using testing::StrictMock;

struct SubscriptionSourceConnectionTests : public NoLoggerFixture {
    SubscriptionSourceConnectionTests()
struct SubscriptionSourceConnectionTestsBase : public NoLoggerFixture {
    SubscriptionSourceConnectionTestsBase()
    {
        subscriptionSource_.run();
    }
@@ -52,7 +58,7 @@ struct SubscriptionSourceConnectionTests : public NoLoggerFixture {
    StrictMockSubscriptionManagerSharedPtr subscriptionManager_;

    StrictMock<MockFunction<void()>> onConnectHook_;
    StrictMock<MockFunction<void()>> onDisconnectHook_;
    StrictMock<MockFunction<void(bool)>> onDisconnectHook_;
    StrictMock<MockFunction<void()>> onLedgerClosedHook_;

    SubscriptionSource subscriptionSource_{
@@ -64,8 +70,8 @@ struct SubscriptionSourceConnectionTests : public NoLoggerFixture {
        onConnectHook_.AsStdFunction(),
        onDisconnectHook_.AsStdFunction(),
        onLedgerClosedHook_.AsStdFunction(),
        std::chrono::milliseconds(1),
        std::chrono::milliseconds(1)
        std::chrono::milliseconds(5),
        std::chrono::milliseconds(5)
    };

    [[maybe_unused]] TestWsConnection
@@ -90,15 +96,17 @@ struct SubscriptionSourceConnectionTests : public NoLoggerFixture {
    }
};

struct SubscriptionSourceConnectionTests : util::prometheus::WithPrometheus, SubscriptionSourceConnectionTestsBase {};

TEST_F(SubscriptionSourceConnectionTests, ConnectionFailed)
{
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

TEST_F(SubscriptionSourceConnectionTests, ConnectionFailed_Retry_ConnectionFailed)
{
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -110,7 +118,19 @@ TEST_F(SubscriptionSourceConnectionTests, ReadError)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

TEST_F(SubscriptionSourceConnectionTests, ReadTimeout)
{
    boost::asio::spawn(ioContext_, [this](boost::asio::yield_context yield) {
        auto connection = serverConnection(yield);
        std::this_thread::sleep_for(std::chrono::milliseconds{10});
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -124,7 +144,7 @@ TEST_F(SubscriptionSourceConnectionTests, ReadError_Reconnect)
    });

    EXPECT_CALL(onConnectHook_, Call()).Times(2);
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -137,14 +157,14 @@ TEST_F(SubscriptionSourceConnectionTests, IsConnected)
    });

    EXPECT_CALL(onConnectHook_, Call()).WillOnce([this]() { EXPECT_TRUE(subscriptionSource_.isConnected()); });
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() {
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() {
        EXPECT_FALSE(subscriptionSource_.isConnected());
        subscriptionSource_.stop();
    });
    ioContext_.run();
}

struct SubscriptionSourceReadTests : public SubscriptionSourceConnectionTests {
struct SubscriptionSourceReadTestsBase : public SubscriptionSourceConnectionTestsBase {
    [[maybe_unused]] TestWsConnection
    connectAndSendMessage(std::string const message, boost::asio::yield_context yield)
    {
@@ -155,6 +175,8 @@ struct SubscriptionSourceReadTests : public SubscriptionSourceConnectionTests {
    }
};

struct SubscriptionSourceReadTests : util::prometheus::WithPrometheus, SubscriptionSourceReadTestsBase {};

TEST_F(SubscriptionSourceReadTests, GotWrongMessage_Reconnect)
{
    boost::asio::spawn(ioContext_, [this](boost::asio::yield_context yield) {
@@ -165,7 +187,7 @@ TEST_F(SubscriptionSourceReadTests, GotWrongMessage_Reconnect)
    });

    EXPECT_CALL(onConnectHook_, Call()).Times(2);
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -177,7 +199,7 @@ TEST_F(SubscriptionSourceReadTests, GotResult)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -189,7 +211,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithLedgerIndex)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(*networkValidatedLedgers_, push(123));
    ioContext_.run();
}
@@ -204,7 +226,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithLedgerIndexAsString_Reconnect)
    });

    EXPECT_CALL(onConnectHook_, Call()).Times(2);
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -218,7 +240,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithValidatedLedgersAsNumber_Reconn
    });

    EXPECT_CALL(onConnectHook_, Call()).Times(2);
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -240,7 +262,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithValidatedLedgers)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();

    EXPECT_TRUE(subscriptionSource_.hasLedger(123));
@@ -266,7 +288,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithValidatedLedgersWrongValue_Reco
    });

    EXPECT_CALL(onConnectHook_, Call()).Times(2);
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -284,7 +306,7 @@ TEST_F(SubscriptionSourceReadTests, GotResultWithLedgerIndexAndValidatedLedgers)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(*networkValidatedLedgers_, push(123));
    ioContext_.run();

@@ -304,7 +326,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosed)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -319,7 +341,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosedForwardingIsSet)

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onLedgerClosedHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() {
    EXPECT_CALL(onDisconnectHook_, Call(true)).WillOnce([this]() {
        EXPECT_FALSE(subscriptionSource_.isForwarding());
        subscriptionSource_.stop();
    });
@@ -334,7 +356,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosedWithLedgerIndex)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(*networkValidatedLedgers_, push(123));
    ioContext_.run();
}
@@ -349,7 +371,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosedWithLedgerIndexAsString_Recon
    });

    EXPECT_CALL(onConnectHook_, Call()).Times(2);
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -363,7 +385,7 @@ TEST_F(SubscriptionSourceReadTests, GorLedgerClosedWithValidatedLedgersAsNumber_
    });

    EXPECT_CALL(onConnectHook_, Call()).Times(2);
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([]() {}).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -380,7 +402,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosedWithValidatedLedgers)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();

    EXPECT_FALSE(subscriptionSource_.hasLedger(0));
@@ -404,7 +426,7 @@ TEST_F(SubscriptionSourceReadTests, GotLedgerClosedWithLedgerIndexAndValidatedLe
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(*networkValidatedLedgers_, push(123));
    ioContext_.run();

@@ -423,7 +445,7 @@ TEST_F(SubscriptionSourceReadTests, GotTransactionIsForwardingFalse)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -438,7 +460,7 @@ TEST_F(SubscriptionSourceReadTests, GotTransactionIsForwardingTrue)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(true)).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(*subscriptionManager_, forwardProposedTransaction(message));
    ioContext_.run();
}
@@ -454,7 +476,7 @@ TEST_F(SubscriptionSourceReadTests, GotTransactionWithMetaIsForwardingFalse)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(true)).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(*subscriptionManager_, forwardProposedTransaction(message)).Times(0);
    ioContext_.run();
}
@@ -467,7 +489,7 @@ TEST_F(SubscriptionSourceReadTests, GotValidationReceivedIsForwardingFalse)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -482,7 +504,7 @@ TEST_F(SubscriptionSourceReadTests, GotValidationReceivedIsForwardingTrue)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(true)).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(*subscriptionManager_, forwardValidation(message));
    ioContext_.run();
}
@@ -495,7 +517,7 @@ TEST_F(SubscriptionSourceReadTests, GotManiefstReceivedIsForwardingFalse)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();
}

@@ -510,7 +532,7 @@ TEST_F(SubscriptionSourceReadTests, GotManifestReceivedIsForwardingTrue)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(true)).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(*subscriptionManager_, forwardManifest(message));
    ioContext_.run();
}
@@ -523,7 +545,7 @@ TEST_F(SubscriptionSourceReadTests, LastMessageTime)
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call()).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    ioContext_.run();

    auto const actualLastTimeMessage = subscriptionSource_.lastMessageTime();
@@ -531,3 +553,27 @@ TEST_F(SubscriptionSourceReadTests, LastMessageTime)
    auto const diff = std::chrono::duration_cast<std::chrono::milliseconds>(now - actualLastTimeMessage);
    EXPECT_LT(diff, std::chrono::milliseconds(100));
}

struct SubscriptionSourcePrometheusCounterTests : util::prometheus::WithMockPrometheus,
                                                  SubscriptionSourceReadTestsBase {};

TEST_F(SubscriptionSourcePrometheusCounterTests, LastMessageTime)
{
    auto& lastMessageTimeMock = makeMock<util::prometheus::GaugeInt>(
        "subscription_source_last_message_time", fmt::format("{{source=\"127.0.0.1:{}\"}}", wsServer_.port())
    );
    boost::asio::spawn(ioContext_, [this](boost::asio::yield_context yield) {
        auto connection = connectAndSendMessage("some_message", yield);
        connection.close(yield);
    });

    EXPECT_CALL(onConnectHook_, Call());
    EXPECT_CALL(onDisconnectHook_, Call(false)).WillOnce([this]() { subscriptionSource_.stop(); });
    EXPECT_CALL(lastMessageTimeMock, set).WillOnce([](int64_t value) {
        auto const now =
            std::chrono::duration_cast<std::chrono::seconds>(std::chrono::system_clock::now().time_since_epoch())
                .count();
        EXPECT_LE(now - value, 1);
    });
    ioContext_.run();
}

@@ -21,6 +21,7 @@
#include "feed/impl/LedgerFeed.hpp"
#include "util/TestObject.hpp"

#include <boost/asio/io_context.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/json/parse.hpp>
#include <gmock/gmock.h>
@@ -57,11 +58,13 @@ TEST_F(FeedLedgerTest, SubPub)
            "reserve_base":3,
            "reserve_inc":2
        })";
    boost::asio::spawn(ctx, [this](boost::asio::yield_context yield) {
    boost::asio::io_context ioContext;
    boost::asio::spawn(ioContext, [this](boost::asio::yield_context yield) {
        auto res = testFeedPtr->sub(yield, backend, sessionPtr);
        // check the response
        EXPECT_EQ(res, json::parse(LedgerResponse));
    });
    ioContext.run();
    EXPECT_EQ(testFeedPtr->count(), 1);

    constexpr static auto ledgerPub =

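A note on the spawn-then-run pattern these test changes adopt (illustrative sketch, not the project's code): the coroutine is only queued on the local io_context and does not execute until run() is called, so any assertion placed after run() observes its effects.

#include <boost/asio/io_context.hpp>
#include <boost/asio/spawn.hpp>

#include <iostream>

int
main()
{
    boost::asio::io_context ioContext;
    bool done = false;

    boost::asio::spawn(ioContext, [&done](boost::asio::yield_context) {
        done = true;  // runs only inside ioContext.run()
    });

    ioContext.run();
    std::cout << std::boolalpha << done << '\n';  // prints: true
}
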
@@ -28,6 +28,7 @@
#include "util/async/context/SyncExecutionContext.hpp"
#include "web/interface/ConnectionBase.hpp"

#include <boost/asio/io_context.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/json/object.hpp>
#include <boost/json/parse.hpp>
@@ -57,13 +58,12 @@ class SubscriptionManagerBaseTest : public util::prometheus::WithPrometheus, pub
protected:
    std::shared_ptr<SubscriptionManager> subscriptionManagerPtr;
    std::shared_ptr<web::ConnectionBase> session;
    Execution ctx{2};
    MockSession* sessionPtr = nullptr;

    void
    SetUp() override
    {
        subscriptionManagerPtr = std::make_shared<SubscriptionManager>(ctx, backend);
        subscriptionManagerPtr = std::make_shared<SubscriptionManager>(Execution(2), backend);
        session = std::make_shared<MockSession>();
        session->apiSubVersion = 1;
        sessionPtr = dynamic_cast<MockSession*>(session.get());
@@ -264,11 +264,13 @@ TEST_F(SubscriptionManagerTest, LedgerTest)
            "reserve_base":3,
            "reserve_inc":2
        })";
    boost::asio::io_context ctx;
    boost::asio::spawn(ctx, [this](boost::asio::yield_context yield) {
        auto const res = subscriptionManagerPtr->subLedger(yield, session);
        // check the response
        EXPECT_EQ(res, json::parse(LedgerResponse));
    });
    ctx.run();
    EXPECT_EQ(subscriptionManagerPtr->report()["ledger"], 1);

    // test publish

@@ -79,7 +79,7 @@ TEST(RPCErrorsTest, StatusAsBool)
TEST(RPCErrorsTest, StatusEquals)
{
    EXPECT_EQ(Status{RippledError::rpcUNKNOWN}, Status{RippledError::rpcUNKNOWN});
    EXPECT_NE(Status{RippledError::rpcUNKNOWN}, Status{RippledError::rpcREPORTING_UNSUPPORTED});
    EXPECT_NE(Status{RippledError::rpcUNKNOWN}, Status{RippledError::rpcINTERNAL});
}

TEST(RPCErrorsTest, SuccessToJSON)

@@ -23,6 +23,7 @@
#include "util/MockCounters.hpp"
#include "util/MockHandlerProvider.hpp"
#include "util/MockLoadBalancer.hpp"
#include "util/NameGenerator.hpp"
#include "util/Taggable.hpp"
#include "util/config/Config.hpp"
#include "web/Context.hpp"
@@ -32,10 +33,12 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <cstdint>
#include <memory>
#include <optional>
#include <string>
#include <variant>
#include <vector>

using namespace rpc;
using namespace testing;
@@ -59,251 +62,159 @@ protected:
 | 
			
		||||
    };
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsFalseIfClioOnly)
 | 
			
		||||
struct ShouldForwardParamTestCaseBundle {
 | 
			
		||||
    std::string testName;
 | 
			
		||||
    std::uint32_t apiVersion;
 | 
			
		||||
    std::string method;
 | 
			
		||||
    std::string testJson;
 | 
			
		||||
    bool mockedIsClioOnly;
 | 
			
		||||
    std::uint32_t called;
 | 
			
		||||
    bool isAdmin;
 | 
			
		||||
    bool expected;
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
struct ShouldForwardParameterTest : public RPCForwardingProxyTest,
 | 
			
		||||
                                    WithParamInterface<ShouldForwardParamTestCaseBundle> {};
 | 
			
		||||
 | 
			
		||||
static auto
 | 
			
		||||
generateTestValuesForParametersTest()
 | 
			
		||||
{
 | 
			
		||||
    auto const isClioOnly = true;
 | 
			
		||||
    auto const isAdmin = true;
 | 
			
		||||
    auto const shouldForward = true;
 | 
			
		||||
 | 
			
		||||
    return std::vector<ShouldForwardParamTestCaseBundle>{
 | 
			
		||||
        {"ShouldForwardReturnsFalseIfClioOnly", 2u, "test", "{}", isClioOnly, 1, !isAdmin, !shouldForward},
 | 
			
		||||
        {"ShouldForwardReturnsTrueIfProxied", 2u, "submit", "{}", !isClioOnly, 1, !isAdmin, shouldForward},
 | 
			
		||||
        {"ShouldForwardReturnsTrueIfCurrentLedgerSpecified",
 | 
			
		||||
         2u,
 | 
			
		||||
         "anymethod",
 | 
			
		||||
         R"({"ledger_index": "current"})",
 | 
			
		||||
         !isClioOnly,
 | 
			
		||||
         1,
 | 
			
		||||
         !isAdmin,
 | 
			
		||||
         shouldForward},
 | 
			
		||||
        {"ShouldForwardReturnsTrueIfClosedLedgerSpecified",
 | 
			
		||||
         2u,
 | 
			
		||||
         "anymethod",
 | 
			
		||||
         R"({"ledger_index": "closed"})",
 | 
			
		||||
         !isClioOnly,
 | 
			
		||||
         1,
 | 
			
		||||
         !isAdmin,
 | 
			
		||||
         shouldForward},
 | 
			
		||||
        {"ShouldForwardReturnsTrueIfAccountInfoWithQueueSpecified",
 | 
			
		||||
         2u,
 | 
			
		||||
         "account_info",
 | 
			
		||||
         R"({"queue": true})",
 | 
			
		||||
         !isClioOnly,
 | 
			
		||||
         1,
 | 
			
		||||
         !isAdmin,
 | 
			
		||||
         shouldForward},
 | 
			
		||||
        {"ShouldForwardReturnsFalseIfAccountInfoQueueIsFalse",
 | 
			
		||||
         2u,
 | 
			
		||||
         "account_info",
 | 
			
		||||
         R"({"queue": false})",
 | 
			
		||||
         !isClioOnly,
 | 
			
		||||
         1,
 | 
			
		||||
         !isAdmin,
 | 
			
		||||
         !shouldForward},
 | 
			
		||||
        {"ShouldForwardReturnsTrueIfLedgerWithQueueSpecified",
 | 
			
		||||
         2u,
 | 
			
		||||
         "ledger",
 | 
			
		||||
         R"({"queue": true})",
 | 
			
		||||
         !isClioOnly,
 | 
			
		||||
         1,
 | 
			
		||||
         !isAdmin,
 | 
			
		||||
         shouldForward},
 | 
			
		||||
        {"ShouldForwardReturnsFalseIfLedgerQueueIsFalse",
 | 
			
		||||
         2u,
 | 
			
		||||
         "ledger",
 | 
			
		||||
         R"({"queue": false})",
 | 
			
		||||
         !isClioOnly,
 | 
			
		||||
         1,
 | 
			
		||||
         !isAdmin,
 | 
			
		||||
         !shouldForward},
 | 
			
		||||
        {"ShouldNotForwardReturnsTrueIfAPIVersionIsV1",
 | 
			
		||||
         1u,
 | 
			
		||||
         "api_version_check",
 | 
			
		||||
         "{}",
 | 
			
		||||
         !isClioOnly,
 | 
			
		||||
         1,
 | 
			
		||||
         !isAdmin,
 | 
			
		||||
         !shouldForward},
 | 
			
		||||
        {"ShouldForwardReturnsFalseIfAPIVersionIsV2",
 | 
			
		||||
         2u,
 | 
			
		||||
         "api_version_check",
 | 
			
		||||
         "{}",
 | 
			
		||||
         !isClioOnly,
 | 
			
		||||
         1,
 | 
			
		||||
         !isAdmin,
 | 
			
		||||
         !shouldForward},
 | 
			
		||||
        {"ShouldNeverForwardSubscribe", 1u, "subscribe", "{}", !isClioOnly, 0, !isAdmin, !shouldForward},
 | 
			
		||||
        {"ShouldNeverForwardUnsubscribe", 1u, "unsubscribe", "{}", !isClioOnly, 0, !isAdmin, !shouldForward},
 | 
			
		||||
        {"ForceForwardTrue", 1u, "any_method", R"({"force_forward": true})", !isClioOnly, 1, isAdmin, shouldForward},
 | 
			
		||||
        {"ForceForwardFalse", 1u, "any_method", R"({"force_forward": false})", !isClioOnly, 1, isAdmin, !shouldForward},
 | 
			
		||||
        {"ForceForwardNotAdmin",
 | 
			
		||||
         1u,
 | 
			
		||||
         "any_method",
 | 
			
		||||
         R"({"force_forward": true})",
 | 
			
		||||
         !isClioOnly,
 | 
			
		||||
         1,
 | 
			
		||||
         !isAdmin,
 | 
			
		||||
         !shouldForward},
 | 
			
		||||
        {"ForceForwardSubscribe",
 | 
			
		||||
         1u,
 | 
			
		||||
         "subscribe",
 | 
			
		||||
         R"({"force_forward": true})",
 | 
			
		||||
         !isClioOnly,
 | 
			
		||||
         0,
 | 
			
		||||
         isAdmin,
 | 
			
		||||
         not shouldForward},
 | 
			
		||||
        {"ForceForwardUnsubscribe",
 | 
			
		||||
         1u,
 | 
			
		||||
         "unsubscribe",
 | 
			
		||||
         R"({"force_forward": true})",
 | 
			
		||||
         !isClioOnly,
 | 
			
		||||
         0,
 | 
			
		||||
         isAdmin,
 | 
			
		||||
         !shouldForward},
 | 
			
		||||
        {"ForceForwardClioOnly",
 | 
			
		||||
         1u,
 | 
			
		||||
         "clio_only_method",
 | 
			
		||||
         R"({"force_forward": true})",
 | 
			
		||||
         isClioOnly,
 | 
			
		||||
         1,
 | 
			
		||||
         isAdmin,
 | 
			
		||||
         !shouldForward},
 | 
			
		||||
    };
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
INSTANTIATE_TEST_CASE_P(
 | 
			
		||||
    ShouldForwardTest,
 | 
			
		||||
    ShouldForwardParameterTest,
 | 
			
		||||
    ValuesIn(generateTestValuesForParametersTest()),
 | 
			
		||||
    tests::util::NameGenerator
 | 
			
		||||
);
 | 
			
		||||
 | 
			
		||||
TEST_P(ShouldForwardParameterTest, Test)
 | 
			
		||||
{
 | 
			
		||||
    auto const testBundle = GetParam();
 | 
			
		||||
    auto const rawHandlerProviderPtr = handlerProvider.get();
 | 
			
		||||
    auto const apiVersion = 2u;
 | 
			
		||||
    auto const method = "test";
 | 
			
		||||
    auto const params = json::parse("{}");
 | 
			
		||||
    auto const apiVersion = testBundle.apiVersion;
 | 
			
		||||
    auto const method = testBundle.method;
 | 
			
		||||
    auto const params = json::parse(testBundle.testJson);
 | 
			
		||||
 | 
			
		||||
    ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(true));
 | 
			
		||||
    EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
 | 
			
		||||
    ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(testBundle.mockedIsClioOnly));
 | 
			
		||||
    EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(testBundle.called);
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
        auto const ctx = web::Context(
 | 
			
		||||
            yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, testBundle.isAdmin
 | 
			
		||||
        );
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_FALSE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsTrueIfProxied)
 | 
			
		||||
{
 | 
			
		||||
    auto const rawHandlerProviderPtr = handlerProvider.get();
 | 
			
		||||
    auto const apiVersion = 2u;
 | 
			
		||||
    auto const method = "submit";
 | 
			
		||||
    auto const params = json::parse("{}");
 | 
			
		||||
 | 
			
		||||
    ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
 | 
			
		||||
    EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_TRUE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsTrueIfCurrentLedgerSpecified)
 | 
			
		||||
{
 | 
			
		||||
    auto const rawHandlerProviderPtr = handlerProvider.get();
 | 
			
		||||
    auto const apiVersion = 2u;
 | 
			
		||||
    auto const method = "anymethod";
 | 
			
		||||
    auto const params = json::parse(R"({"ledger_index": "current"})");
 | 
			
		||||
 | 
			
		||||
    ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
 | 
			
		||||
    EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_TRUE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsTrueIfClosedLedgerSpecified)
 | 
			
		||||
{
 | 
			
		||||
    auto const rawHandlerProviderPtr = handlerProvider.get();
 | 
			
		||||
    auto const apiVersion = 2u;
 | 
			
		||||
    auto const method = "anymethod";
 | 
			
		||||
    auto const params = json::parse(R"({"ledger_index": "closed"})");
 | 
			
		||||
 | 
			
		||||
    ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
 | 
			
		||||
    EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_TRUE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsTrueIfAccountInfoWithQueueSpecified)
 | 
			
		||||
{
 | 
			
		||||
    auto const rawHandlerProviderPtr = handlerProvider.get();
 | 
			
		||||
    auto const apiVersion = 2u;
 | 
			
		||||
    auto const method = "account_info";
 | 
			
		||||
    auto const params = json::parse(R"({"queue": true})");
 | 
			
		||||
 | 
			
		||||
    ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
 | 
			
		||||
    EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_TRUE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsFalseIfAccountInfoQueueIsFalse)
 | 
			
		||||
{
 | 
			
		||||
    auto const rawHandlerProviderPtr = handlerProvider.get();
 | 
			
		||||
    auto const apiVersion = 2u;
 | 
			
		||||
    auto const method = "account_info";
 | 
			
		||||
    auto const params = json::parse(R"({"queue": false})");
 | 
			
		||||
 | 
			
		||||
    ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
 | 
			
		||||
    EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_FALSE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsTrueIfLedgerWithQueueSpecified)
 | 
			
		||||
{
 | 
			
		||||
    auto const rawHandlerProviderPtr = handlerProvider.get();
 | 
			
		||||
    auto const apiVersion = 2u;
 | 
			
		||||
    auto const method = "ledger";
 | 
			
		||||
    auto const params = json::parse(R"({"queue": true})");
 | 
			
		||||
 | 
			
		||||
    ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
 | 
			
		||||
    EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_TRUE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsFalseIfLedgerQueueIsFalse)
 | 
			
		||||
{
 | 
			
		||||
    auto const rawHandlerProviderPtr = handlerProvider.get();
 | 
			
		||||
    auto const apiVersion = 2u;
 | 
			
		||||
    auto const method = "ledger";
 | 
			
		||||
    auto const params = json::parse(R"({"queue": false})");
 | 
			
		||||
 | 
			
		||||
    ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
 | 
			
		||||
    EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_FALSE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldNotForwardReturnsTrueIfAPIVersionIsV1)
 | 
			
		||||
{
 | 
			
		||||
    auto const apiVersion = 1u;
 | 
			
		||||
    auto const method = "api_version_check";
 | 
			
		||||
    auto const params = json::parse("{}");
 | 
			
		||||
 | 
			
		||||
    auto const rawHandlerProviderPtr = handlerProvider.get();
 | 
			
		||||
    ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
 | 
			
		||||
    EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_FALSE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldForwardReturnsFalseIfAPIVersionIsV2)
 | 
			
		||||
{
 | 
			
		||||
    auto const rawHandlerProviderPtr = handlerProvider.get();
 | 
			
		||||
    auto const apiVersion = 2u;
 | 
			
		||||
    auto const method = "api_version_check";
 | 
			
		||||
    auto const params = json::parse("{}");
 | 
			
		||||
 | 
			
		||||
    ON_CALL(*rawHandlerProviderPtr, isClioOnly(_)).WillByDefault(Return(false));
 | 
			
		||||
    EXPECT_CALL(*rawHandlerProviderPtr, isClioOnly(method)).Times(1);
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_FALSE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldNeverForwardFeatureWithVetoedFlag)
 | 
			
		||||
{
 | 
			
		||||
    auto const apiVersion = 1u;
 | 
			
		||||
    auto const method = "feature";
 | 
			
		||||
    auto const params = json::parse(R"({"vetoed": true, "feature": "foo"})");
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_FALSE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldNeverForwardSubscribe)
 | 
			
		||||
{
 | 
			
		||||
    auto const apiVersion = 1u;
 | 
			
		||||
    auto const method = "subscribe";
 | 
			
		||||
    auto const params = json::parse("{}");
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_FALSE(res);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCForwardingProxyTest, ShouldNeverForwardUnsubscribe)
 | 
			
		||||
{
 | 
			
		||||
    auto const apiVersion = 1u;
 | 
			
		||||
    auto const method = "unsubscribe";
 | 
			
		||||
    auto const params = json::parse("{}");
 | 
			
		||||
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const range = backend->fetchLedgerRange();
 | 
			
		||||
        auto const ctx =
 | 
			
		||||
            web::Context(yield, method, apiVersion, params.as_object(), nullptr, tagFactory, *range, CLIENT_IP, true);
 | 
			
		||||
 | 
			
		||||
        auto const res = proxy.shouldForward(ctx);
 | 
			
		||||
        ASSERT_FALSE(res);
 | 
			
		||||
        ASSERT_EQ(res, testBundle.expected);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -25,6 +25,7 @@
#include "util/AsioContextTestFixture.hpp"
#include "util/MockBackendTestFixture.hpp"
#include "util/MockPrometheus.hpp"
#include "util/NameGenerator.hpp"
#include "util/TestObject.hpp"

#include <boost/asio/impl/spawn.hpp>
@@ -539,3 +540,67 @@ TEST_F(RPCHelpersTest, ParseIssue)
        std::runtime_error
    );
}

struct IsAdminCmdParamTestCaseBundle {
    std::string testName;
    std::string method;
    std::string testJson;
    bool expected;
};

struct IsAdminCmdParameterTest : public TestWithParam<IsAdminCmdParamTestCaseBundle> {};

static auto
generateTestValuesForParametersTest()
{
    return std::vector<IsAdminCmdParamTestCaseBundle>{
        {"ledgerEntry", "ledger_entry", R"({"type": false})", false},

        {"featureVetoedTrue", "feature", R"({"vetoed": true, "feature": "foo"})", true},
        {"featureVetoedFalse", "feature", R"({"vetoed": false, "feature": "foo"})", true},
        {"featureVetoedIsStr", "feature", R"({"vetoed": "String"})", true},

        {"ledger", "ledger", R"({})", false},
        {"ledgerWithType", "ledger", R"({"type": "fee"})", false},
        {"ledgerFullTrue", "ledger", R"({"full": true})", true},
        {"ledgerFullFalse", "ledger", R"({"full": false})", false},
        {"ledgerFullIsStr", "ledger", R"({"full": "String"})", true},
        {"ledgerFullIsEmptyStr", "ledger", R"({"full": ""})", false},
        {"ledgerFullIsNumber1", "ledger", R"({"full": 1})", true},
        {"ledgerFullIsNumber0", "ledger", R"({"full": 0})", false},
        {"ledgerFullIsNull", "ledger", R"({"full": null})", false},
        {"ledgerFullIsFloat0", "ledger", R"({"full": 0.0})", false},
        {"ledgerFullIsFloat1", "ledger", R"({"full": 0.1})", true},
        {"ledgerFullIsArray", "ledger", R"({"full": [1]})", true},
        {"ledgerFullIsEmptyArray", "ledger", R"({"full": []})", false},
        {"ledgerFullIsObject", "ledger", R"({"full": {"key": 1}})", true},
        {"ledgerFullIsEmptyObject", "ledger", R"({"full": {}})", false},

        {"ledgerAccountsTrue", "ledger", R"({"accounts": true})", true},
        {"ledgerAccountsFalse", "ledger", R"({"accounts": false})", false},
        {"ledgerAccountsIsStr", "ledger", R"({"accounts": "String"})", true},
        {"ledgerAccountsIsEmptyStr", "ledger", R"({"accounts": ""})", false},
        {"ledgerAccountsIsNumber1", "ledger", R"({"accounts": 1})", true},
        {"ledgerAccountsIsNumber0", "ledger", R"({"accounts": 0})", false},
        {"ledgerAccountsIsNull", "ledger", R"({"accounts": null})", false},
        {"ledgerAccountsIsFloat0", "ledger", R"({"accounts": 0.0})", false},
        {"ledgerAccountsIsFloat1", "ledger", R"({"accounts": 0.1})", true},
        {"ledgerAccountsIsArray", "ledger", R"({"accounts": [1]})", true},
        {"ledgerAccountsIsEmptyArray", "ledger", R"({"accounts": []})", false},
        {"ledgerAccountsIsObject", "ledger", R"({"accounts": {"key": 1}})", true},
        {"ledgerAccountsIsEmptyObject", "ledger", R"({"accounts": {}})", false},
    };
}

INSTANTIATE_TEST_CASE_P(
    IsAdminCmdTest,
    IsAdminCmdParameterTest,
    ValuesIn(generateTestValuesForParametersTest()),
    tests::util::NameGenerator
);

TEST_P(IsAdminCmdParameterTest, Test)
{
    auto const testBundle = GetParam();
    EXPECT_EQ(isAdminCmd(testBundle.method, boost::json::parse(testBundle.testJson).as_object()), testBundle.expected);
}

@@ -49,6 +49,7 @@ constexpr static auto TAXON = 0;
 | 
			
		||||
constexpr static auto FLAG = 8;
 | 
			
		||||
constexpr static auto TXNID = "E6DBAFC99223B42257915A63DFC6B0C032D4070F9A574B255AD97466726FC321";
 | 
			
		||||
constexpr static auto PAGE = "E6DBAFC99223B42257915A63DFC6B0C032D4070F9A574B255AD97466726FC322";
 | 
			
		||||
constexpr static auto INVALIDPAGE = "E6DBAFC99223B42257915A63DFC6B0C032D4070F9A574B255AD97466726FCAAA";
 | 
			
		||||
constexpr static auto MAXSEQ = 30;
 | 
			
		||||
constexpr static auto MINSEQ = 10;
 | 
			
		||||
 | 
			
		||||
@@ -402,6 +403,98 @@ TEST_F(RPCAccountNFTsHandlerTest, Marker)
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCAccountNFTsHandlerTest, InvalidMarker)
 | 
			
		||||
{
 | 
			
		||||
    backend->setRange(MINSEQ, MAXSEQ);
 | 
			
		||||
    auto const ledgerHeader = CreateLedgerHeader(LEDGERHASH, MAXSEQ);
 | 
			
		||||
    EXPECT_CALL(*backend, fetchLedgerBySequence).Times(1);
 | 
			
		||||
    ON_CALL(*backend, fetchLedgerBySequence).WillByDefault(Return(ledgerHeader));
 | 
			
		||||
 | 
			
		||||
    auto const accountObject = CreateAccountRootObject(ACCOUNT, 0, 1, 10, 2, TXNID, 3);
 | 
			
		||||
    auto const accountID = GetAccountIDWithString(ACCOUNT);
 | 
			
		||||
    ON_CALL(*backend, doFetchLedgerObject(ripple::keylet::account(accountID).key, 30, _))
 | 
			
		||||
        .WillByDefault(Return(accountObject.getSerializer().peekData()));
 | 
			
		||||
 | 
			
		||||
    auto static const input = json::parse(fmt::format(
 | 
			
		||||
        R"({{
 | 
			
		||||
            "account":"{}",
 | 
			
		||||
            "marker":"{}"
 | 
			
		||||
        }})",
 | 
			
		||||
        ACCOUNT,
 | 
			
		||||
        INVALIDPAGE
 | 
			
		||||
    ));
 | 
			
		||||
    auto const handler = AnyHandler{AccountNFTsHandler{backend}};
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const output = handler.process(input, Context{yield});
 | 
			
		||||
        ASSERT_FALSE(output);
 | 
			
		||||
        auto const err = rpc::makeError(output.result.error());
 | 
			
		||||
        EXPECT_EQ(err.at("error").as_string(), "invalidParams");
 | 
			
		||||
        EXPECT_EQ(err.at("error_message").as_string(), "Marker field does not match any valid Page ID");
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCAccountNFTsHandlerTest, AccountWithNoNFT)
 | 
			
		||||
{
 | 
			
		||||
    backend->setRange(MINSEQ, MAXSEQ);
 | 
			
		||||
    auto const ledgerHeader = CreateLedgerHeader(LEDGERHASH, MAXSEQ);
 | 
			
		||||
    EXPECT_CALL(*backend, fetchLedgerBySequence).Times(1);
 | 
			
		||||
    ON_CALL(*backend, fetchLedgerBySequence).WillByDefault(Return(ledgerHeader));
 | 
			
		||||
 | 
			
		||||
    auto const accountObject = CreateAccountRootObject(ACCOUNT, 0, 1, 10, 2, TXNID, 3);
 | 
			
		||||
    auto const accountID = GetAccountIDWithString(ACCOUNT);
 | 
			
		||||
    ON_CALL(*backend, doFetchLedgerObject(ripple::keylet::account(accountID).key, 30, _))
 | 
			
		||||
        .WillByDefault(Return(accountObject.getSerializer().peekData()));
 | 
			
		||||
 | 
			
		||||
    auto static const input = json::parse(fmt::format(
 | 
			
		||||
        R"({{
 | 
			
		||||
            "account":"{}"
 | 
			
		||||
        }})",
 | 
			
		||||
        ACCOUNT
 | 
			
		||||
    ));
 | 
			
		||||
    auto const handler = AnyHandler{AccountNFTsHandler{backend}};
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const output = handler.process(input, Context{yield});
 | 
			
		||||
        ASSERT_TRUE(output);
 | 
			
		||||
        EXPECT_EQ(output.result->as_object().at("account_nfts").as_array().size(), 0);
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCAccountNFTsHandlerTest, invalidPage)
 | 
			
		||||
{
 | 
			
		||||
    backend->setRange(MINSEQ, MAXSEQ);
 | 
			
		||||
    auto const ledgerHeader = CreateLedgerHeader(LEDGERHASH, MAXSEQ);
 | 
			
		||||
    EXPECT_CALL(*backend, fetchLedgerBySequence).Times(1);
 | 
			
		||||
    ON_CALL(*backend, fetchLedgerBySequence).WillByDefault(Return(ledgerHeader));
 | 
			
		||||
 | 
			
		||||
    auto const accountObject = CreateAccountRootObject(ACCOUNT, 0, 1, 10, 2, TXNID, 3);
 | 
			
		||||
    auto const accountID = GetAccountIDWithString(ACCOUNT);
 | 
			
		||||
    ON_CALL(*backend, doFetchLedgerObject(ripple::keylet::account(accountID).key, 30, _))
 | 
			
		||||
        .WillByDefault(Return(accountObject.getSerializer().peekData()));
 | 
			
		||||
 | 
			
		||||
    auto const pageObject =
 | 
			
		||||
        CreateNFTTokenPage(std::vector{std::make_pair<std::string, std::string>(TOKENID, "www.ok.com")}, std::nullopt);
 | 
			
		||||
    ON_CALL(*backend, doFetchLedgerObject(ripple::uint256{PAGE}, 30, _))
 | 
			
		||||
        .WillByDefault(Return(accountObject.getSerializer().peekData()));
 | 
			
		||||
    EXPECT_CALL(*backend, doFetchLedgerObject).Times(2);
 | 
			
		||||
 | 
			
		||||
    auto static const input = json::parse(fmt::format(
 | 
			
		||||
        R"({{
 | 
			
		||||
            "account":"{}",
 | 
			
		||||
            "marker":"{}"
 | 
			
		||||
        }})",
 | 
			
		||||
        ACCOUNT,
 | 
			
		||||
        PAGE
 | 
			
		||||
    ));
 | 
			
		||||
    auto const handler = AnyHandler{AccountNFTsHandler{backend}};
 | 
			
		||||
    runSpawn([&](auto yield) {
 | 
			
		||||
        auto const output = handler.process(input, Context{yield});
 | 
			
		||||
        ASSERT_FALSE(output);
 | 
			
		||||
        auto const err = rpc::makeError(output.result.error());
 | 
			
		||||
        EXPECT_EQ(err.at("error").as_string(), "invalidParams");
 | 
			
		||||
        EXPECT_EQ(err.at("error_message").as_string(), "Marker matches Page ID from another Account");
 | 
			
		||||
    });
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(RPCAccountNFTsHandlerTest, LimitLessThanMin)
 | 
			
		||||
{
 | 
			
		||||
    static auto const expectedOutput = fmt::format(
 | 
			
		||||
 
 | 
			
		||||
@@ -136,22 +136,13 @@ generateTestValuesForParametersTest()
        SubscribeParamTestCaseBundle{"StreamNotString", R"({"streams": [1]})", "invalidParams", "streamNotString"},
        SubscribeParamTestCaseBundle{"StreamNotValid", R"({"streams": ["1"]})", "malformedStream", "Stream malformed."},
        SubscribeParamTestCaseBundle{
            "StreamPeerStatusNotSupport",
            R"({"streams": ["peer_status"]})",
            "reportingUnsupported",
            "Requested operation not supported by reporting mode server"
            "StreamPeerStatusNotSupport", R"({"streams": ["peer_status"]})", "notSupported", "Operation not supported."
        },
        SubscribeParamTestCaseBundle{
            "StreamConsensusNotSupport",
            R"({"streams": ["consensus"]})",
            "reportingUnsupported",
            "Requested operation not supported by reporting mode server"
            "StreamConsensusNotSupport", R"({"streams": ["consensus"]})", "notSupported", "Operation not supported."
        },
        SubscribeParamTestCaseBundle{
            "StreamServerNotSupport",
            R"({"streams": ["server"]})",
            "reportingUnsupported",
            "Requested operation not supported by reporting mode server"
            "StreamServerNotSupport", R"({"streams": ["server"]})", "notSupported", "Operation not supported."
        },
        SubscribeParamTestCaseBundle{"BooksNotArray", R"({"books": "1"})", "invalidParams", "booksNotArray"},
        SubscribeParamTestCaseBundle{

@@ -486,22 +486,13 @@ generateTestValuesForParametersTest()
            "bothNotBool"
        },
        UnsubscribeParamTestCaseBundle{
            "StreamPeerStatusNotSupport",
            R"({"streams": ["peer_status"]})",
            "reportingUnsupported",
            "Requested operation not supported by reporting mode server"
            "StreamPeerStatusNotSupport", R"({"streams": ["peer_status"]})", "notSupported", "Operation not supported."
        },
        UnsubscribeParamTestCaseBundle{
            "StreamConsensusNotSupport",
            R"({"streams": ["consensus"]})",
            "reportingUnsupported",
            "Requested operation not supported by reporting mode server"
            "StreamConsensusNotSupport", R"({"streams": ["consensus"]})", "notSupported", "Operation not supported."
        },
        UnsubscribeParamTestCaseBundle{
            "StreamServerNotSupport",
            R"({"streams": ["server"]})",
            "reportingUnsupported",
            "Requested operation not supported by reporting mode server"
            "StreamServerNotSupport", R"({"streams": ["server"]})", "notSupported", "Operation not supported."
        },
    };
}

@@ -18,33 +18,69 @@
//==============================================================================

#include "util/newconfig/Array.hpp"
#include "util/newconfig/ConfigConstraints.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"
#include "util/newconfig/ValueView.hpp"

#include <gtest/gtest.h>

#include <algorithm>
#include <cstdint>
#include <vector>

using namespace util::config;

TEST(ArrayTest, testConfigArray)
TEST(ArrayTest, addSingleValue)
{
    auto arr = Array{
        ConfigValue{ConfigType::Boolean}.defaultValue(false),
        ConfigValue{ConfigType::Integer}.defaultValue(1234),
        ConfigValue{ConfigType::Double}.defaultValue(22.22),
    };
    auto cv = arr.at(0);
    ValueView const vv{cv};
    EXPECT_EQ(vv.asBool(), false);
    auto arr = Array{ConfigValue{ConfigType::Double}};
    arr.addValue(111.11);
    EXPECT_EQ(arr.size(), 1);
}

    auto cv2 = arr.at(1);
TEST(ArrayTest, addAndCheckMultipleValues)
{
    auto arr = Array{ConfigValue{ConfigType::Double}};
    arr.addValue(111.11);
    arr.addValue(222.22);
    arr.addValue(333.33);
    EXPECT_EQ(arr.size(), 3);

    auto const cv = arr.at(0);
    ValueView const vv{cv};
    EXPECT_EQ(vv.asDouble(), 111.11);

    auto const cv2 = arr.at(1);
    ValueView const vv2{cv2};
    EXPECT_EQ(vv2.asIntType<int>(), 1234);
    EXPECT_EQ(vv2.asDouble(), 222.22);

    EXPECT_EQ(arr.size(), 3);
    arr.emplaceBack(ConfigValue{ConfigType::String}.defaultValue("false"));
    arr.addValue(444.44);

    EXPECT_EQ(arr.size(), 4);
    auto cv4 = arr.at(3);
    auto const cv4 = arr.at(3);
    ValueView const vv4{cv4};
    EXPECT_EQ(vv4.asString(), "false");
    EXPECT_EQ(vv4.asDouble(), 444.44);
}

TEST(ArrayTest, testArrayPattern)
{
    auto const arr = Array{ConfigValue{ConfigType::String}};
    auto const arrPattern = arr.getArrayPattern();
    EXPECT_EQ(arrPattern.type(), ConfigType::String);
}

TEST(ArrayTest, iterateValueArray)
{
    auto arr = Array{ConfigValue{ConfigType::Integer}.withConstraint(validateUint16)};
    std::vector<int64_t> const expected{543, 123, 909};

    for (auto const num : expected)
        arr.addValue(num);

    std::vector<int64_t> actual;
    for (auto it = arr.begin(); it != arr.end(); ++it)
        actual.emplace_back(std::get<int64_t>(it->getValue()));

    EXPECT_TRUE(std::ranges::equal(expected, actual));
}

@@ -19,11 +19,13 @@
 | 
			
		||||
 | 
			
		||||
#include "util/newconfig/ArrayView.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigDefinition.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigValue.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigFileJson.hpp"
 | 
			
		||||
#include "util/newconfig/FakeConfigData.hpp"
 | 
			
		||||
#include "util/newconfig/ObjectView.hpp"
 | 
			
		||||
#include "util/newconfig/Types.hpp"
 | 
			
		||||
#include "util/newconfig/ValueView.hpp"
 | 
			
		||||
 | 
			
		||||
#include <boost/json/parse.hpp>
 | 
			
		||||
#include <gtest/gtest.h>
 | 
			
		||||
 | 
			
		||||
#include <cstddef>
 | 
			
		||||
@@ -31,33 +33,65 @@
 | 
			
		||||
using namespace util::config;
 | 
			
		||||
 | 
			
		||||
struct ArrayViewTest : testing::Test {
 | 
			
		||||
    ClioConfigDefinition const configData = generateConfig();
 | 
			
		||||
    ArrayViewTest()
 | 
			
		||||
    {
 | 
			
		||||
        ConfigFileJson const jsonFileObj{boost::json::parse(JSONData).as_object()};
 | 
			
		||||
        auto const errors = configData.parse(jsonFileObj);
 | 
			
		||||
        EXPECT_TRUE(!errors.has_value());
 | 
			
		||||
    }
 | 
			
		||||
    ClioConfigDefinition configData = generateConfig();
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
TEST_F(ArrayViewTest, ArrayValueTest)
 | 
			
		||||
// Array View tests can only be tested after the values are populated from user Config
 | 
			
		||||
// into ConfigClioDefinition
 | 
			
		||||
TEST_F(ArrayViewTest, ArrayGetValueDouble)
 | 
			
		||||
{
 | 
			
		||||
    ArrayView const arrVals = configData.getArray("array.[].sub");
 | 
			
		||||
    auto valIt = arrVals.begin<ValueView>();
 | 
			
		||||
    auto const precision = 1e-9;
 | 
			
		||||
    EXPECT_NEAR((*valIt++).asDouble(), 111.11, precision);
 | 
			
		||||
    EXPECT_NEAR((*valIt++).asDouble(), 4321.55, precision);
 | 
			
		||||
    EXPECT_EQ(valIt, arrVals.end<ValueView>());
 | 
			
		||||
    ArrayView const arrVals = configData.getArray("array.[].sub");
 | 
			
		||||
 | 
			
		||||
    EXPECT_NEAR(111.11, arrVals.valueAt(0).asDouble(), precision);
 | 
			
		||||
    auto const firstVal = arrVals.valueAt(0);
 | 
			
		||||
    EXPECT_EQ(firstVal.type(), ConfigType::Double);
 | 
			
		||||
    EXPECT_TRUE(firstVal.hasValue());
 | 
			
		||||
    EXPECT_FALSE(firstVal.isOptional());
 | 
			
		||||
 | 
			
		||||
    EXPECT_NEAR(111.11, firstVal.asDouble(), precision);
 | 
			
		||||
    EXPECT_NEAR(4321.55, arrVals.valueAt(1).asDouble(), precision);
 | 
			
		||||
 | 
			
		||||
    ArrayView const arrVals2 = configData.getArray("array.[].sub2");
 | 
			
		||||
    auto val2It = arrVals2.begin<ValueView>();
 | 
			
		||||
    EXPECT_EQ((*val2It++).asString(), "subCategory");
 | 
			
		||||
    EXPECT_EQ((*val2It++).asString(), "temporary");
 | 
			
		||||
    EXPECT_EQ(val2It, arrVals2.end<ValueView>());
 | 
			
		||||
 | 
			
		||||
    ValueView const tempVal = arrVals2.valueAt(0);
 | 
			
		||||
    EXPECT_EQ(tempVal.type(), ConfigType::String);
 | 
			
		||||
    EXPECT_EQ("subCategory", tempVal.asString());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ArrayViewTest, ArrayWithObjTest)
 | 
			
		||||
TEST_F(ArrayViewTest, ArrayGetValueString)
 | 
			
		||||
{
 | 
			
		||||
    ArrayView const arrVals = configData.getArray("array.[].sub2");
 | 
			
		||||
    ValueView const firstVal = arrVals.valueAt(0);
 | 
			
		||||
 | 
			
		||||
    EXPECT_EQ(firstVal.type(), ConfigType::String);
 | 
			
		||||
    EXPECT_EQ("subCategory", firstVal.asString());
 | 
			
		||||
    EXPECT_EQ("london", arrVals.valueAt(2).asString());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ArrayViewTest, IterateValuesDouble)
 | 
			
		||||
{
 | 
			
		||||
    auto const precision = 1e-9;
 | 
			
		||||
    ArrayView const arrVals = configData.getArray("array.[].sub");
 | 
			
		||||
 | 
			
		||||
    auto valIt = arrVals.begin<ValueView>();
 | 
			
		||||
    EXPECT_NEAR((*valIt++).asDouble(), 111.11, precision);
 | 
			
		||||
    EXPECT_NEAR((*valIt++).asDouble(), 4321.55, precision);
 | 
			
		||||
    EXPECT_NEAR((*valIt++).asDouble(), 5555.44, precision);
 | 
			
		||||
    EXPECT_EQ(valIt, arrVals.end<ValueView>());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ArrayViewTest, IterateValuesString)
 | 
			
		||||
{
 | 
			
		||||
    ArrayView const arrVals = configData.getArray("array.[].sub2");
 | 
			
		||||
 | 
			
		||||
    auto val2It = arrVals.begin<ValueView>();
 | 
			
		||||
    EXPECT_EQ((*val2It++).asString(), "subCategory");
 | 
			
		||||
    EXPECT_EQ((*val2It++).asString(), "temporary");
 | 
			
		||||
    EXPECT_EQ((*val2It++).asString(), "london");
 | 
			
		||||
    EXPECT_EQ(val2It, arrVals.end<ValueView>());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ArrayViewTest, ArrayWithObj)
 | 
			
		||||
{
 | 
			
		||||
    ArrayView const arrVals = configData.getArray("array.[]");
 | 
			
		||||
    ArrayView const arrValAlt = configData.getArray("array");
 | 
			
		||||
@@ -73,20 +107,19 @@ TEST_F(ArrayViewTest, IterateArray)
 | 
			
		||||
{
 | 
			
		||||
    auto arr = configData.getArray("dosguard.whitelist");
 | 
			
		||||
    EXPECT_EQ(2, arr.size());
 | 
			
		||||
    EXPECT_EQ(arr.valueAt(0).asString(), "125.5.5.2");
 | 
			
		||||
    EXPECT_EQ(arr.valueAt(1).asString(), "204.2.2.2");
 | 
			
		||||
    EXPECT_EQ(arr.valueAt(0).asString(), "125.5.5.1");
 | 
			
		||||
    EXPECT_EQ(arr.valueAt(1).asString(), "204.2.2.1");
 | 
			
		||||
 | 
			
		||||
    auto it = arr.begin<ValueView>();
 | 
			
		||||
    EXPECT_EQ((*it++).asString(), "125.5.5.2");
 | 
			
		||||
    EXPECT_EQ((*it++).asString(), "204.2.2.2");
 | 
			
		||||
    EXPECT_EQ((*it++).asString(), "125.5.5.1");
 | 
			
		||||
    EXPECT_EQ((*it++).asString(), "204.2.2.1");
 | 
			
		||||
    EXPECT_EQ((it), arr.end<ValueView>());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ArrayViewTest, DifferentArrayIterators)
 | 
			
		||||
TEST_F(ArrayViewTest, CompareDifferentArrayIterators)
 | 
			
		||||
{
 | 
			
		||||
    auto const subArray = configData.getArray("array.[].sub");
 | 
			
		||||
    auto const dosguardArray = configData.getArray("dosguard.whitelist.[]");
 | 
			
		||||
    ASSERT_EQ(subArray.size(), dosguardArray.size());
 | 
			
		||||
 | 
			
		||||
    auto itArray = subArray.begin<ValueView>();
 | 
			
		||||
    auto itDosguard = dosguardArray.begin<ValueView>();
 | 
			
		||||
@@ -98,7 +131,7 @@ TEST_F(ArrayViewTest, DifferentArrayIterators)
 | 
			
		||||
TEST_F(ArrayViewTest, IterateObject)
 | 
			
		||||
{
 | 
			
		||||
    auto arr = configData.getArray("array");
 | 
			
		||||
    EXPECT_EQ(2, arr.size());
 | 
			
		||||
    EXPECT_EQ(3, arr.size());
 | 
			
		||||
 | 
			
		||||
    auto it = arr.begin<ObjectView>();
 | 
			
		||||
    EXPECT_EQ(111.11, (*it).getValue("sub").asDouble());
 | 
			
		||||
@@ -107,33 +140,37 @@ TEST_F(ArrayViewTest, IterateObject)
 | 
			
		||||
    EXPECT_EQ(4321.55, (*it).getValue("sub").asDouble());
 | 
			
		||||
    EXPECT_EQ("temporary", (*it++).getValue("sub2").asString());
 | 
			
		||||
 | 
			
		||||
    EXPECT_EQ(5555.44, (*it).getValue("sub").asDouble());
 | 
			
		||||
    EXPECT_EQ("london", (*it++).getValue("sub2").asString());
 | 
			
		||||
 | 
			
		||||
    EXPECT_EQ(it, arr.end<ObjectView>());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
struct ArrayViewDeathTest : ArrayViewTest {};
 | 
			
		||||
 | 
			
		||||
TEST_F(ArrayViewDeathTest, IncorrectAccess)
 | 
			
		||||
TEST_F(ArrayViewDeathTest, AccessArrayOutOfBounce)
 | 
			
		||||
{
 | 
			
		||||
    ArrayView const arr = configData.getArray("higher");
 | 
			
		||||
    // dies because higher only has 1 object (trying to access 2nd element)
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getArray("higher").objectAt(1); }, ".*");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
    // dies because higher only has 1 object
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto _ = arr.objectAt(1); }, ".*");
 | 
			
		||||
 | 
			
		||||
    ArrayView const arrVals2 = configData.getArray("array.[].sub2");
 | 
			
		||||
    ValueView const tempVal = arrVals2.valueAt(0);
 | 
			
		||||
 | 
			
		||||
    // dies because array.[].sub2 only has 2 config values
 | 
			
		||||
    EXPECT_DEATH([[maybe_unused]] auto _ = arrVals2.valueAt(2), ".*");
 | 
			
		||||
TEST_F(ArrayViewDeathTest, AccessIndexOfWrongType)
 | 
			
		||||
{
 | 
			
		||||
    auto const& arrVals2 = configData.getArray("array.[].sub2");
 | 
			
		||||
    auto const& tempVal = arrVals2.valueAt(0);
 | 
			
		||||
 | 
			
		||||
    // dies as value is not of type int
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto _ = tempVal.asIntType<int>(); }, ".*");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ArrayViewDeathTest, IncorrectIterateAccess)
 | 
			
		||||
TEST_F(ArrayViewDeathTest, GetValueWhenItIsObject)
 | 
			
		||||
{
 | 
			
		||||
    ArrayView const arr = configData.getArray("higher");
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto _ = arr.begin<ValueView>(); }, ".*");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(ArrayViewDeathTest, GetObjectWhenItIsValue)
 | 
			
		||||
{
 | 
			
		||||
    ArrayView const dosguardWhitelist = configData.getArray("dosguard.whitelist");
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto _ = dosguardWhitelist.begin<ObjectView>(); }, ".*");
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -20,17 +20,25 @@
 | 
			
		||||
#include "util/newconfig/ArrayView.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigDefinition.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigDescription.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigValue.hpp"
 | 
			
		||||
#include "util/newconfig/ConfigFileJson.hpp"
 | 
			
		||||
#include "util/newconfig/FakeConfigData.hpp"
 | 
			
		||||
#include "util/newconfig/Types.hpp"
 | 
			
		||||
#include "util/newconfig/ValueView.hpp"
 | 
			
		||||
 | 
			
		||||
#include <boost/json/object.hpp>
 | 
			
		||||
#include <boost/json/parse.hpp>
 | 
			
		||||
#include <boost/json/value.hpp>
 | 
			
		||||
#include <gtest/gtest.h>
 | 
			
		||||
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <string_view>
 | 
			
		||||
#include <unordered_set>
 | 
			
		||||
#include <vector>
 | 
			
		||||
 | 
			
		||||
using namespace util::config;
 | 
			
		||||
 | 
			
		||||
// TODO: parsing config file and populating into config will be here once implemented
 | 
			
		||||
struct NewConfigTest : testing::Test {
 | 
			
		||||
    ClioConfigDefinition const configData = generateConfig();
 | 
			
		||||
};
 | 
			
		||||
@@ -45,12 +53,9 @@ TEST_F(NewConfigTest, fetchValues)
 | 
			
		||||
    EXPECT_EQ(true, configData.getValue("header.admin").asBool());
 | 
			
		||||
    EXPECT_EQ("TSM", configData.getValue("header.sub.sub2Value").asString());
 | 
			
		||||
    EXPECT_EQ(444.22, configData.getValue("ip").asDouble());
 | 
			
		||||
 | 
			
		||||
    auto const v2 = configData.getValueInArray("dosguard.whitelist", 0);
 | 
			
		||||
    EXPECT_EQ(v2.asString(), "125.5.5.2");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(NewConfigTest, fetchObject)
 | 
			
		||||
TEST_F(NewConfigTest, fetchObjectDirectly)
 | 
			
		||||
{
 | 
			
		||||
    auto const obj = configData.getObject("header");
 | 
			
		||||
    EXPECT_TRUE(obj.containsKey("sub.sub2Value"));
 | 
			
		||||
@@ -58,27 +63,6 @@ TEST_F(NewConfigTest, fetchObject)
 | 
			
		||||
    auto const obj2 = obj.getObject("sub");
 | 
			
		||||
    EXPECT_TRUE(obj2.containsKey("sub2Value"));
 | 
			
		||||
    EXPECT_EQ(obj2.getValue("sub2Value").asString(), "TSM");
 | 
			
		||||
 | 
			
		||||
    auto const objInArr = configData.getObject("array", 0);
 | 
			
		||||
    auto const obj2InArr = configData.getObject("array", 1);
 | 
			
		||||
    EXPECT_EQ(objInArr.getValue("sub").asDouble(), 111.11);
 | 
			
		||||
    EXPECT_EQ(objInArr.getValue("sub2").asString(), "subCategory");
 | 
			
		||||
    EXPECT_EQ(obj2InArr.getValue("sub").asDouble(), 4321.55);
 | 
			
		||||
    EXPECT_EQ(obj2InArr.getValue("sub2").asString(), "temporary");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(NewConfigTest, fetchArray)
 | 
			
		||||
{
 | 
			
		||||
    auto const obj = configData.getObject("dosguard");
 | 
			
		||||
    EXPECT_TRUE(obj.containsKey("whitelist.[]"));
 | 
			
		||||
 | 
			
		||||
    auto const arr = obj.getArray("whitelist");
 | 
			
		||||
    EXPECT_EQ(2, arr.size());
 | 
			
		||||
 | 
			
		||||
    auto const sameArr = configData.getArray("dosguard.whitelist");
 | 
			
		||||
    EXPECT_EQ(2, sameArr.size());
 | 
			
		||||
    EXPECT_EQ(sameArr.valueAt(0).asString(), arr.valueAt(0).asString());
 | 
			
		||||
    EXPECT_EQ(sameArr.valueAt(1).asString(), arr.valueAt(1).asString());
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(NewConfigTest, CheckKeys)
 | 
			
		||||
@@ -91,9 +75,11 @@ TEST_F(NewConfigTest, CheckKeys)
 | 
			
		||||
    EXPECT_TRUE(configData.hasItemsWithPrefix("dosguard"));
 | 
			
		||||
    EXPECT_TRUE(configData.hasItemsWithPrefix("ip"));
 | 
			
		||||
 | 
			
		||||
    EXPECT_EQ(configData.arraySize("array"), 2);
 | 
			
		||||
    EXPECT_EQ(configData.arraySize("higher"), 1);
 | 
			
		||||
    EXPECT_EQ(configData.arraySize("dosguard.whitelist"), 2);
 | 
			
		||||
    // all arrays currently not populated, only has "itemPattern_" that defines
 | 
			
		||||
    // the type/constraint each configValue will have later on
 | 
			
		||||
    EXPECT_EQ(configData.arraySize("array"), 0);
 | 
			
		||||
    EXPECT_EQ(configData.arraySize("higher"), 0);
 | 
			
		||||
    EXPECT_EQ(configData.arraySize("dosguard.whitelist"), 0);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(NewConfigTest, CheckAllKeys)
 | 
			
		||||
@@ -110,7 +96,10 @@ TEST_F(NewConfigTest, CheckAllKeys)
 | 
			
		||||
        "higher.[].low.section",
 | 
			
		||||
        "higher.[].low.admin",
 | 
			
		||||
        "dosguard.whitelist.[]",
 | 
			
		||||
        "dosguard.port"
 | 
			
		||||
        "dosguard.port",
 | 
			
		||||
        "optional.withDefault",
 | 
			
		||||
        "optional.withNoDefault",
 | 
			
		||||
        "requireValue"
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    for (auto i = configData.begin(); i != configData.end(); ++i) {
 | 
			
		||||
@@ -121,31 +110,42 @@ TEST_F(NewConfigTest, CheckAllKeys)
 | 
			
		||||
 | 
			
		||||
struct NewConfigDeathTest : NewConfigTest {};
 | 
			
		||||
 | 
			
		||||
TEST_F(NewConfigDeathTest, IncorrectGetValues)
 | 
			
		||||
TEST_F(NewConfigDeathTest, GetNonExistentKeys)
 | 
			
		||||
{
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getValue("head"); }, ".*");
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getValue("head."); }, ".*");
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getValue("asdf"); }, ".*");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(NewConfigDeathTest, GetValueButIsArray)
 | 
			
		||||
{
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getValue("dosguard.whitelist"); }, ".*");
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getValue("dosguard.whitelist.[]"); }, ".*");
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
TEST_F(NewConfigDeathTest, IncorrectGetObject)
 | 
			
		||||
TEST_F(NewConfigDeathTest, GetNonExistentObjectKey)
 | 
			
		||||
{
 | 
			
		||||
    ASSERT_FALSE(configData.contains("head"));
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("head"); }, ".*");
 | 
			
		||||
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("array"); }, ".*");
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("array", 2); }, ".*");
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("doesNotExist"); }, ".*");
}

TEST_F(NewConfigDeathTest, IncorrectGetArray)
TEST_F(NewConfigDeathTest, GetObjectButIsArray)
{
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("array"); }, ".*");
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getObject("array", 2); }, ".*");
}

TEST_F(NewConfigDeathTest, GetArrayButIsValue)
{
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getArray("header.text1"); }, ".*");
}

TEST_F(NewConfigDeathTest, GetNonExistentArrayKey)
{
    EXPECT_DEATH({ [[maybe_unused]] auto a_ = configData.getArray("asdf"); }, ".*");
}

TEST(ConfigDescription, getValues)
TEST(ConfigDescription, GetValues)
{
    ClioConfigDescription const definition{};

@@ -154,10 +154,150 @@ TEST(ConfigDescription, getValues)
    EXPECT_EQ(definition.get("prometheus.enabled"), "Enable or disable Prometheus metrics.");
}

TEST(ConfigDescriptionAssertDeathTest, nonExistingKeyTest)
TEST(ConfigDescriptionAssertDeathTest, NonExistingKeyTest)
{
    ClioConfigDescription const definition{};

    EXPECT_DEATH({ [[maybe_unused]] auto a = definition.get("data"); }, ".*");
    EXPECT_DEATH({ [[maybe_unused]] auto a = definition.get("etl_source.[]"); }, ".*");
}

/** @brief Testing override the default values with the ones in Json */
struct OverrideConfigVals : testing::Test {
    OverrideConfigVals()
    {
        ConfigFileJson const jsonFileObj{boost::json::parse(JSONData).as_object()};
        auto const errors = configData.parse(jsonFileObj);
        EXPECT_TRUE(!errors.has_value());
    }
    ClioConfigDefinition configData = generateConfig();
};

TEST_F(OverrideConfigVals, ValidateValuesStrings)
{
    // make sure the values in configData are overriden
    EXPECT_TRUE(configData.contains("header.text1"));
    EXPECT_EQ(configData.getValue("header.text1").asString(), "value");

    EXPECT_FALSE(configData.contains("header.sub"));
    EXPECT_TRUE(configData.contains("header.sub.sub2Value"));
    EXPECT_EQ(configData.getValue("header.sub.sub2Value").asString(), "TSM");

    EXPECT_TRUE(configData.contains("requireValue"));
    EXPECT_EQ(configData.getValue("requireValue").asString(), "required");
}

TEST_F(OverrideConfigVals, ValidateValuesDouble)
{
    EXPECT_TRUE(configData.contains("optional.withDefault"));
    EXPECT_EQ(configData.getValue("optional.withDefault").asDouble(), 0.0);

    // make sure the values not overwritten, (default values) are there too
    EXPECT_TRUE(configData.contains("ip"));
    EXPECT_EQ(configData.getValue("ip").asDouble(), 444.22);
}

TEST_F(OverrideConfigVals, ValidateValuesInteger)
{
    EXPECT_TRUE(configData.contains("dosguard.port"));
    EXPECT_EQ(configData.getValue("dosguard.port").asIntType<int>(), 44444);

    EXPECT_TRUE(configData.contains("header.port"));
    EXPECT_EQ(configData.getValue("header.port").asIntType<int64_t>(), 321);
}

TEST_F(OverrideConfigVals, ValidateValuesBool)
{
    EXPECT_TRUE(configData.contains("header.admin"));
    EXPECT_EQ(configData.getValue("header.admin").asBool(), false);
}

TEST_F(OverrideConfigVals, ValidateIntegerValuesInArrays)
{
    // Check array values (sub)
    EXPECT_TRUE(configData.contains("array.[].sub"));
    auto const arrSub = configData.getArray("array.[].sub");

    std::vector<double> expectedArrSubVal{111.11, 4321.55, 5555.44};
    std::vector<double> actualArrSubVal{};
    for (auto it = arrSub.begin<ValueView>(); it != arrSub.end<ValueView>(); ++it) {
        actualArrSubVal.emplace_back((*it).asDouble());
    }
    EXPECT_TRUE(std::ranges::equal(expectedArrSubVal, actualArrSubVal));
}

TEST_F(OverrideConfigVals, ValidateStringValuesInArrays)
{
    // Check array values (sub2)
    EXPECT_TRUE(configData.contains("array.[].sub2"));
    auto const arrSub2 = configData.getArray("array.[].sub2");

    std::vector<std::string> expectedArrSub2Val{"subCategory", "temporary", "london"};
    std::vector<std::string> actualArrSub2Val{};
    for (auto it = arrSub2.begin<ValueView>(); it != arrSub2.end<ValueView>(); ++it) {
        actualArrSub2Val.emplace_back((*it).asString());
    }
    EXPECT_TRUE(std::ranges::equal(expectedArrSub2Val, actualArrSub2Val));

    // Check dosguard values
    EXPECT_TRUE(configData.contains("dosguard.whitelist.[]"));
    auto const dosguard = configData.getArray("dosguard.whitelist.[]");
    EXPECT_EQ("125.5.5.1", dosguard.valueAt(0).asString());
    EXPECT_EQ("204.2.2.1", dosguard.valueAt(1).asString());
}

TEST_F(OverrideConfigVals, FetchArray)
{
    auto const obj = configData.getObject("dosguard");
    EXPECT_TRUE(obj.containsKey("whitelist.[]"));

    auto const arr = obj.getArray("whitelist");
    EXPECT_EQ(2, arr.size());

    auto const sameArr = configData.getArray("dosguard.whitelist");
    EXPECT_EQ(2, sameArr.size());
    EXPECT_EQ(sameArr.valueAt(0).asString(), arr.valueAt(0).asString());
    EXPECT_EQ(sameArr.valueAt(1).asString(), arr.valueAt(1).asString());
}

TEST_F(OverrideConfigVals, FetchObjectByArray)
{
    auto const objInArr = configData.getObject("array", 0);
    auto const obj2InArr = configData.getObject("array", 1);
    auto const obj3InArr = configData.getObject("array", 2);

    EXPECT_EQ(objInArr.getValue("sub").asDouble(), 111.11);
    EXPECT_EQ(objInArr.getValue("sub2").asString(), "subCategory");
    EXPECT_EQ(obj2InArr.getValue("sub").asDouble(), 4321.55);
    EXPECT_EQ(obj2InArr.getValue("sub2").asString(), "temporary");
    EXPECT_EQ(obj3InArr.getValue("sub").asDouble(), 5555.44);
    EXPECT_EQ(obj3InArr.getValue("sub2").asString(), "london");
}

struct IncorrectOverrideValues : testing::Test {
    ClioConfigDefinition configData = generateConfig();
};

TEST_F(IncorrectOverrideValues, InvalidJsonErrors)
{
    ConfigFileJson const jsonFileObj{boost::json::parse(invalidJSONData).as_object()};
    auto const errors = configData.parse(jsonFileObj);
    EXPECT_TRUE(errors.has_value());

    // Expected error messages
    std::unordered_set<std::string_view> const expectedErrors{
        "dosguard.whitelist.[] value does not match type string",
        "higher.[].low.section key is required in user Config",
        "higher.[].low.admin key is required in user Config",
        "array.[].sub key is required in user Config",
        "header.port value does not match type integer",
        "header.admin value does not match type boolean",
        "optional.withDefault value does not match type double"
    };

    std::unordered_set<std::string_view> actualErrors;
    for (auto const& error : errors.value()) {
        actualErrors.insert(error.error);
    }
    EXPECT_EQ(expectedErrors, actualErrors);
}
 
@@ -17,24 +17,207 @@
*/
//==============================================================================

#include "util/newconfig/ConfigConstraints.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/Types.hpp"

#include <fmt/core.h>
#include <gtest/gtest.h>

#include <array>
#include <string>

using namespace util::config;

TEST(ConfigValue, testConfigValue)
TEST(ConfigValue, GetSetString)
{
    auto cvStr = ConfigValue{ConfigType::String}.defaultValue("12345");
    auto const cvStr = ConfigValue{ConfigType::String}.defaultValue("12345");
    EXPECT_EQ(cvStr.type(), ConfigType::String);
    EXPECT_TRUE(cvStr.hasValue());
    EXPECT_FALSE(cvStr.isOptional());
}

    auto cvInt = ConfigValue{ConfigType::Integer}.defaultValue(543);
TEST(ConfigValue, GetSetInteger)
{
    auto const cvInt = ConfigValue{ConfigType::Integer}.defaultValue(543);
    EXPECT_EQ(cvInt.type(), ConfigType::Integer);
    EXPECT_TRUE(cvStr.hasValue());
    EXPECT_FALSE(cvStr.isOptional());
    EXPECT_TRUE(cvInt.hasValue());
    EXPECT_FALSE(cvInt.isOptional());

    auto cvOpt = ConfigValue{ConfigType::Integer}.optional();
    auto const cvOpt = ConfigValue{ConfigType::Integer}.optional();
    EXPECT_TRUE(cvOpt.isOptional());
}

// A test for each constraint so it's easy to change in the future
TEST(ConfigValue, PortConstraint)
{
    auto const portConstraint{PortConstraint{}};
    EXPECT_FALSE(portConstraint.checkConstraint(4444).has_value());
    EXPECT_TRUE(portConstraint.checkConstraint(99999).has_value());
}

TEST(ConfigValue, SetValuesOnPortConstraint)
{
    auto cvPort = ConfigValue{ConfigType::Integer}.defaultValue(4444).withConstraint(validatePort);
    auto const err = cvPort.setValue(99999);
    EXPECT_TRUE(err.has_value());
    EXPECT_EQ(err->error, "Port does not satisfy the constraint bounds");
    EXPECT_TRUE(cvPort.setValue(33.33).has_value());
    EXPECT_TRUE(cvPort.setValue(33.33).value().error == "value does not match type integer");
    EXPECT_FALSE(cvPort.setValue(1).has_value());

    auto cvPort2 = ConfigValue{ConfigType::String}.defaultValue("4444").withConstraint(validatePort);
    auto const strPortError = cvPort2.setValue("100000");
    EXPECT_TRUE(strPortError.has_value());
    EXPECT_EQ(strPortError->error, "Port does not satisfy the constraint bounds");
}

TEST(ConfigValue, OneOfConstraintOneValue)
{
    std::array<char const*, 1> const arr = {"tracer"};
    auto const databaseConstraint{OneOf{"database.type", arr}};
    EXPECT_FALSE(databaseConstraint.checkConstraint("tracer").has_value());

    EXPECT_TRUE(databaseConstraint.checkConstraint(345).has_value());
    EXPECT_EQ(databaseConstraint.checkConstraint(345)->error, R"(Key "database.type"'s value must be a string)");

    EXPECT_TRUE(databaseConstraint.checkConstraint("123.44").has_value());
    EXPECT_EQ(
        databaseConstraint.checkConstraint("123.44")->error,
        R"(You provided value "123.44". Key "database.type"'s value must be one of the following: tracer)"
    );
}

TEST(ConfigValue, OneOfConstraint)
{
    std::array<char const*, 3> const arr = {"123", "trace", "haha"};
    auto const oneOfCons{OneOf{"log_level", arr}};

    EXPECT_FALSE(oneOfCons.checkConstraint("trace").has_value());

    EXPECT_TRUE(oneOfCons.checkConstraint(345).has_value());
    EXPECT_EQ(oneOfCons.checkConstraint(345)->error, R"(Key "log_level"'s value must be a string)");

    EXPECT_TRUE(oneOfCons.checkConstraint("PETER_WAS_HERE").has_value());
    EXPECT_EQ(
        oneOfCons.checkConstraint("PETER_WAS_HERE")->error,
        R"(You provided value "PETER_WAS_HERE". Key "log_level"'s value must be one of the following: 123, trace, haha)"
    );
}

TEST(ConfigValue, IpConstraint)
{
    auto ip = ConfigValue{ConfigType::String}.defaultValue("127.0.0.1").withConstraint(validateIP);
    EXPECT_FALSE(ip.setValue("http://127.0.0.1").has_value());
    EXPECT_FALSE(ip.setValue("http://127.0.0.1.com").has_value());
    auto const err = ip.setValue("123.44");
    EXPECT_TRUE(err.has_value());
    EXPECT_EQ(err->error, "Ip is not a valid ip address");
    EXPECT_FALSE(ip.setValue("126.0.0.2"));

    EXPECT_TRUE(ip.setValue("644.3.3.0"));
    EXPECT_TRUE(ip.setValue("127.0.0.1.0"));
    EXPECT_TRUE(ip.setValue(""));
    EXPECT_TRUE(ip.setValue("http://example..com"));
    EXPECT_FALSE(ip.setValue("localhost"));
    EXPECT_FALSE(ip.setValue("http://example.com:8080/path"));
}

TEST(ConfigValue, positiveNumConstraint)
{
    auto const numCons{NumberValueConstraint{0, 5}};
    EXPECT_FALSE(numCons.checkConstraint(0));
    EXPECT_FALSE(numCons.checkConstraint(5));

    EXPECT_TRUE(numCons.checkConstraint(true));
    EXPECT_EQ(numCons.checkConstraint(true)->error, fmt::format("Number must be of type integer"));

    EXPECT_TRUE(numCons.checkConstraint(8));
    EXPECT_EQ(numCons.checkConstraint(8)->error, fmt::format("Number must be between {} and {}", 0, 5));
}

TEST(ConfigValue, SetValuesOnNumberConstraint)
{
    auto positiveNum = ConfigValue{ConfigType::Integer}.defaultValue(20u).withConstraint(validateUint16);
    auto const err = positiveNum.setValue(-22, "key");
    EXPECT_TRUE(err.has_value());
    EXPECT_EQ(err->error, fmt::format("key Number must be between {} and {}", 0, 65535));
    EXPECT_FALSE(positiveNum.setValue(99, "key"));
}

TEST(ConfigValue, PositiveDoubleConstraint)
{
    auto const doubleCons{PositiveDouble{}};
    EXPECT_FALSE(doubleCons.checkConstraint(0.2));
    EXPECT_FALSE(doubleCons.checkConstraint(5.54));
    EXPECT_TRUE(doubleCons.checkConstraint("-5"));
    EXPECT_EQ(doubleCons.checkConstraint("-5")->error, "Double number must be of type int or double");
    EXPECT_EQ(doubleCons.checkConstraint(-5.6)->error, "Double number must be greater than 0");
    EXPECT_FALSE(doubleCons.checkConstraint(12.1));
}

struct ConstraintTestBundle {
    std::string name;
    Constraint const& cons_;
};

struct ConstraintDeathTest : public testing::Test, public testing::WithParamInterface<ConstraintTestBundle> {};

INSTANTIATE_TEST_SUITE_P(
    EachConstraints,
    ConstraintDeathTest,
    testing::Values(
        ConstraintTestBundle{"logTagConstraint", validateLogTag},
        ConstraintTestBundle{"portConstraint", validatePort},
        ConstraintTestBundle{"ipConstraint", validateIP},
        ConstraintTestBundle{"channelConstraint", validateChannelName},
        ConstraintTestBundle{"logLevelConstraint", validateLogLevelName},
        ConstraintTestBundle{"cannsandraNameCnstraint", validateCassandraName},
        ConstraintTestBundle{"loadModeConstraint", validateLoadMode},
        ConstraintTestBundle{"ChannelNameConstraint", validateChannelName},
        ConstraintTestBundle{"ApiVersionConstraint", validateApiVersion},
        ConstraintTestBundle{"Uint16Constraint", validateUint16},
        ConstraintTestBundle{"Uint32Constraint", validateUint32},
        ConstraintTestBundle{"PositiveDoubleConstraint", validatePositiveDouble}
    ),
    [](testing::TestParamInfo<ConstraintTestBundle> const& info) { return info.param.name; }
);

TEST_P(ConstraintDeathTest, TestEachConstraint)
{
    EXPECT_DEATH(
        {
            [[maybe_unused]] auto const a =
                ConfigValue{ConfigType::Boolean}.defaultValue(true).withConstraint(GetParam().cons_);
        },
        ".*"
    );
}

TEST(ConfigValueDeathTest, SetInvalidValueTypeStringAndBool)
{
    EXPECT_DEATH(
        {
            [[maybe_unused]] auto a = ConfigValue{ConfigType::String}.defaultValue(33).withConstraint(validateLoadMode);
        },
        ".*"
    );
    EXPECT_DEATH({ [[maybe_unused]] auto a = ConfigValue{ConfigType::Boolean}.defaultValue(-66); }, ".*");
}

TEST(ConfigValueDeathTest, OutOfBounceIntegerConstraint)
{
    EXPECT_DEATH(
        {
            [[maybe_unused]] auto a =
                ConfigValue{ConfigType::Integer}.defaultValue(999999).withConstraint(validateUint16);
        },
        ".*"
    );
    EXPECT_DEATH(
        {
            [[maybe_unused]] auto a = ConfigValue{ConfigType::Integer}.defaultValue(-66).withConstraint(validateUint32);
        },
        ".*"
    );
}
 
113 tests/unit/util/newconfig/JsonConfigFileTests.cpp Normal file
@@ -0,0 +1,113 @@
//------------------------------------------------------------------------------
/*
    This file is part of clio: https://github.com/XRPLF/clio
    Copyright (c) 2024, the clio developers.

    Permission to use, copy, modify, and distribute this software for any
    purpose with or without fee is hereby granted, provided that the above
    copyright notice and this permission notice appear in all copies.

    THE  SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
    WITH  REGARD  TO  THIS  SOFTWARE  INCLUDING  ALL  IMPLIED  WARRANTIES  OF
    MERCHANTABILITY  AND  FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
    ANY  SPECIAL,  DIRECT,  INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
    WHATSOEVER  RESULTING  FROM  LOSS  OF USE, DATA OR PROFITS, WHETHER IN AN
    ACTION  OF  CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
    OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================

#include "util/TmpFile.hpp"
#include "util/newconfig/ConfigFileJson.hpp"
#include "util/newconfig/FakeConfigData.hpp"

#include <boost/json/parse.hpp>
#include <gtest/gtest.h>

#include <algorithm>
#include <cstdint>
#include <string>
#include <variant>
#include <vector>

TEST(CreateConfigFile, filePath)
{
    auto const jsonFileObj = ConfigFileJson::make_ConfigFileJson(TmpFile(JSONData).path);
    EXPECT_TRUE(jsonFileObj.has_value());

    EXPECT_TRUE(jsonFileObj->containsKey("array.[].sub"));
    auto const arrSub = jsonFileObj->getArray("array.[].sub");
    EXPECT_EQ(arrSub.size(), 3);
}

TEST(CreateConfigFile, incorrectFilePath)
{
    auto const jsonFileObj = util::config::ConfigFileJson::make_ConfigFileJson("123/clio");
    EXPECT_FALSE(jsonFileObj.has_value());
}

struct ParseJson : testing::Test {
    ParseJson() : jsonFileObj{boost::json::parse(JSONData).as_object()}
    {
    }

    ConfigFileJson const jsonFileObj;
};

TEST_F(ParseJson, validateValues)
{
    EXPECT_TRUE(jsonFileObj.containsKey("header.text1"));
    EXPECT_EQ(std::get<std::string>(jsonFileObj.getValue("header.text1")), "value");

    EXPECT_TRUE(jsonFileObj.containsKey("header.sub.sub2Value"));
    EXPECT_EQ(std::get<std::string>(jsonFileObj.getValue("header.sub.sub2Value")), "TSM");

    EXPECT_TRUE(jsonFileObj.containsKey("dosguard.port"));
    EXPECT_EQ(std::get<int64_t>(jsonFileObj.getValue("dosguard.port")), 44444);

    EXPECT_FALSE(jsonFileObj.containsKey("idk"));
    EXPECT_FALSE(jsonFileObj.containsKey("optional.withNoDefault"));
}

TEST_F(ParseJson, validateArrayValue)
{
    // validate array.[].sub matches expected values
    EXPECT_TRUE(jsonFileObj.containsKey("array.[].sub"));
    auto const arrSub = jsonFileObj.getArray("array.[].sub");
    EXPECT_EQ(arrSub.size(), 3);

    std::vector<double> expectedArrSubVal{111.11, 4321.55, 5555.44};
    std::vector<double> actualArrSubVal{};

    for (auto it = arrSub.begin(); it != arrSub.end(); ++it) {
        ASSERT_TRUE(std::holds_alternative<double>(*it));
        actualArrSubVal.emplace_back(std::get<double>(*it));
    }
    EXPECT_TRUE(std::ranges::equal(expectedArrSubVal, actualArrSubVal));

    // validate array.[].sub2 matches expected values
    EXPECT_TRUE(jsonFileObj.containsKey("array.[].sub2"));
    auto const arrSub2 = jsonFileObj.getArray("array.[].sub2");
    EXPECT_EQ(arrSub2.size(), 3);
    std::vector<std::string> expectedArrSub2Val{"subCategory", "temporary", "london"};
    std::vector<std::string> actualArrSub2Val{};

    for (auto it = arrSub2.begin(); it != arrSub2.end(); ++it) {
        ASSERT_TRUE(std::holds_alternative<std::string>(*it));
        actualArrSub2Val.emplace_back(std::get<std::string>(*it));
    }
    EXPECT_TRUE(std::ranges::equal(expectedArrSub2Val, actualArrSub2Val));

    EXPECT_TRUE(jsonFileObj.containsKey("dosguard.whitelist.[]"));
    auto const whitelistArr = jsonFileObj.getArray("dosguard.whitelist.[]");
    EXPECT_EQ(whitelistArr.size(), 2);
    EXPECT_EQ("125.5.5.1", std::get<std::string>(whitelistArr.at(0)));
    EXPECT_EQ("204.2.2.1", std::get<std::string>(whitelistArr.at(1)));
}

struct JsonValueDeathTest : ParseJson {};

TEST_F(JsonValueDeathTest, invalidGetArray)
{
    EXPECT_DEATH([[maybe_unused]] auto a = jsonFileObj.getArray("header.text1"), ".*");
}
@@ -19,15 +19,23 @@

#include "util/newconfig/ArrayView.hpp"
#include "util/newconfig/ConfigDefinition.hpp"
#include "util/newconfig/ConfigFileJson.hpp"
#include "util/newconfig/FakeConfigData.hpp"
#include "util/newconfig/ObjectView.hpp"

#include <boost/json/parse.hpp>
#include <gtest/gtest.h>

using namespace util::config;

struct ObjectViewTest : testing::Test {
    ClioConfigDefinition const configData = generateConfig();
    ObjectViewTest()
    {
        ConfigFileJson const jsonFileObj{boost::json::parse(JSONData).as_object()};
        auto const errors = configData.parse(jsonFileObj);
        EXPECT_TRUE(!errors.has_value());
    }
    ClioConfigDefinition configData = generateConfig();
};

TEST_F(ObjectViewTest, ObjectValueTest)
@@ -39,14 +47,14 @@ TEST_F(ObjectViewTest, ObjectValueTest)
    EXPECT_TRUE(headerObj.containsKey("admin"));

    EXPECT_EQ("value", headerObj.getValue("text1").asString());
    EXPECT_EQ(123, headerObj.getValue("port").asIntType<int>());
    EXPECT_EQ(true, headerObj.getValue("admin").asBool());
    EXPECT_EQ(321, headerObj.getValue("port").asIntType<int>());
    EXPECT_EQ(false, headerObj.getValue("admin").asBool());
}

TEST_F(ObjectViewTest, ObjectInArray)
TEST_F(ObjectViewTest, ObjectValuesInArray)
{
    ArrayView const arr = configData.getArray("array");
    EXPECT_EQ(arr.size(), 2);
    EXPECT_EQ(arr.size(), 3);
    ObjectView const firstObj = arr.objectAt(0);
    ObjectView const secondObj = arr.objectAt(1);
    EXPECT_TRUE(firstObj.containsKey("sub"));
@@ -62,7 +70,7 @@ TEST_F(ObjectViewTest, ObjectInArray)
    EXPECT_EQ(secondObj.getValue("sub2").asString(), "temporary");
}

TEST_F(ObjectViewTest, ObjectInArrayMoreComplex)
TEST_F(ObjectViewTest, GetObjectsInDifferentWays)
{
    ArrayView const arr = configData.getArray("higher");
    ASSERT_EQ(1, arr.size());
@@ -78,7 +86,7 @@ TEST_F(ObjectViewTest, ObjectInArrayMoreComplex)
    ObjectView const objLow = firstObj.getObject("low");
    EXPECT_TRUE(objLow.containsKey("section"));
    EXPECT_TRUE(objLow.containsKey("admin"));
    EXPECT_EQ(objLow.getValue("section").asString(), "true");
    EXPECT_EQ(objLow.getValue("section").asString(), "WebServer");
    EXPECT_EQ(objLow.getValue("admin").asBool(), false);
}

@@ -90,18 +98,25 @@ TEST_F(ObjectViewTest, getArrayInObject)
    auto const arr = obj.getArray("whitelist");
    EXPECT_EQ(2, arr.size());

    EXPECT_EQ("125.5.5.2", arr.valueAt(0).asString());
    EXPECT_EQ("204.2.2.2", arr.valueAt(1).asString());
    EXPECT_EQ("125.5.5.1", arr.valueAt(0).asString());
    EXPECT_EQ("204.2.2.1", arr.valueAt(1).asString());
}

struct ObjectViewDeathTest : ObjectViewTest {};

TEST_F(ObjectViewDeathTest, incorrectKeys)
TEST_F(ObjectViewDeathTest, KeyDoesNotExist)
{
    EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getObject("head"); }, ".*");
}

TEST_F(ObjectViewDeathTest, KeyIsValueView)
{
    EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getObject("header.text1"); }, ".*");
    EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getObject("head"); }, ".*");
    EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getArray("header"); }, ".*");
}

TEST_F(ObjectViewDeathTest, KeyisArrayView)
{
    // dies because only 1 object in higher.[].low
    EXPECT_DEATH({ [[maybe_unused]] auto _ = configData.getObject("higher.[].low", 1); }, ".*");
}
 
@@ -20,6 +20,7 @@
#include "util/newconfig/ConfigDefinition.hpp"
#include "util/newconfig/ConfigValue.hpp"
#include "util/newconfig/FakeConfigData.hpp"
#include "util/newconfig/Types.hpp"
#include "util/newconfig/ValueView.hpp"

#include <gtest/gtest.h>
 
@@ -25,6 +25,8 @@
#include <boost/asio/error.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/beast/http/field.hpp>
#include <boost/beast/websocket/stream.hpp>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <chrono>
@@ -33,6 +35,7 @@
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <thread>
#include <vector>

@@ -318,6 +321,40 @@ TEST_F(WsConnectionTests, MultipleConnections)
    }
}

TEST_F(WsConnectionTests, RespondsToPing)
{
    asio::spawn(ctx, [&](asio::yield_context yield) {
        auto serverConnection = unwrap(server.acceptConnection(yield));

        testing::StrictMock<testing::MockFunction<void(boost::beast::websocket::frame_type, std::string_view)>>
            controlFrameCallback;
        serverConnection.setControlFrameCallback(controlFrameCallback.AsStdFunction());
        EXPECT_CALL(controlFrameCallback, Call(boost::beast::websocket::frame_type::pong, testing::_)).WillOnce([&]() {
            serverConnection.resetControlFrameCallback();
            asio::spawn(ctx, [&](asio::yield_context yield) {
                auto maybeError = serverConnection.send("got pong", yield);
                ASSERT_FALSE(maybeError.has_value()) << *maybeError;
            });
        });

        serverConnection.sendPing({}, yield);
        auto message = serverConnection.receive(yield);
        ASSERT_TRUE(message.has_value());
        EXPECT_EQ(message, "hello") << message.value();
    });

    runSpawn([&](asio::yield_context yield) {
        auto connection = builder.plainConnect(yield);
        ASSERT_TRUE(connection.has_value()) << connection.error().message();
        auto expectedMessage = connection->operator*().read(yield);
        ASSERT_TRUE(expectedMessage) << expectedMessage.error().message();
        EXPECT_EQ(expectedMessage.value(), "got pong");

        auto error = connection->operator*().write("hello", yield);
        ASSERT_FALSE(error) << error->message();
    });
}

enum class WsConnectionErrorTestsBundle : int { Read = 1, Write = 2 };

struct WsConnectionErrorTests : WsConnectionTestsBase, testing::WithParamInterface<WsConnectionErrorTestsBundle> {};
 
210 tools/cassandra_delete_range/cassandra_delete_range.go Normal file
@@ -0,0 +1,210 @@
//
// Based off of https://github.com/scylladb/scylla-code-samples/blob/master/efficient_full_table_scan_example_code/efficient_full_table_scan.go
//

package main

import (
	"fmt"
	"log"
	"os"
	"strings"
	"time"
	"xrplf/clio/cassandra_delete_range/internal/cass"
	"xrplf/clio/cassandra_delete_range/internal/util"

	"github.com/alecthomas/kingpin/v2"
	"github.com/gocql/gocql"
)

const (
	defaultNumberOfNodesInCluster = 3
	defaultNumberOfCoresInNode    = 8
	defaultSmudgeFactor           = 3
)

var (
	app   = kingpin.New("cassandra_delete_range", "A tool that prunes data from the Clio DB.")
	hosts = app.Flag("hosts", "Your Scylla nodes IP addresses, comma separated (i.e. 192.168.1.1,192.168.1.2,192.168.1.3)").Required().String()

	deleteAfter          = app.Command("delete-after", "Prunes from the given ledger index until the end")
	deleteAfterLedgerIdx = deleteAfter.Arg("idx", "Sets the earliest ledger_index to keep untouched (delete everything after this ledger index)").Required().Uint64()

	deleteBefore          = app.Command("delete-before", "Prunes everything before the given ledger index")
	deleteBeforeLedgerIdx = deleteBefore.Arg("idx", "Sets the latest ledger_index to keep around (delete everything before this ledger index)").Required().Uint64()

	getLedgerRange = app.Command("get-ledger-range", "Fetch the current lender_range table values")

	nodesInCluster        = app.Flag("nodes-in-cluster", "Number of nodes in your Scylla cluster").Short('n').Default(fmt.Sprintf("%d", defaultNumberOfNodesInCluster)).Int()
	coresInNode           = app.Flag("cores-in-node", "Number of cores in each node").Short('c').Default(fmt.Sprintf("%d", defaultNumberOfCoresInNode)).Int()
	smudgeFactor          = app.Flag("smudge-factor", "Yet another factor to make parallelism cooler").Short('s').Default(fmt.Sprintf("%d", defaultSmudgeFactor)).Int()
	clusterConsistency    = app.Flag("consistency", "Cluster consistency level. Use 'localone' for multi DC").Short('o').Default("localquorum").String()
	clusterTimeout        = app.Flag("timeout", "Maximum duration for query execution in millisecond").Short('t').Default("15000").Int()
	clusterNumConnections = app.Flag("cluster-number-of-connections", "Number of connections per host per session (in our case, per thread)").Short('b').Default("1").Int()
	clusterCQLVersion     = app.Flag("cql-version", "The CQL version to use").Short('l').Default("3.0.0").String()
	clusterPageSize       = app.Flag("cluster-page-size", "Page size of results").Short('p').Default("5000").Int()
	keyspace              = app.Flag("keyspace", "Keyspace to use").Short('k').Default("clio_fh").String()

	userName = app.Flag("username", "Username to use when connecting to the cluster").String()
	password = app.Flag("password", "Password to use when connecting to the cluster").String()

	skipSuccessorTable          = app.Flag("skip-successor", "Whether to skip deletion from successor table").Default("false").Bool()
	skipObjectsTable            = app.Flag("skip-objects", "Whether to skip deletion from objects table").Default("false").Bool()
	skipLedgerHashesTable       = app.Flag("skip-ledger-hashes", "Whether to skip deletion from ledger_hashes table").Default("false").Bool()
	skipTransactionsTable       = app.Flag("skip-transactions", "Whether to skip deletion from transactions table").Default("false").Bool()
	skipDiffTable               = app.Flag("skip-diff", "Whether to skip deletion from diff table").Default("false").Bool()
	skipLedgerTransactionsTable = app.Flag("skip-ledger-transactions", "Whether to skip deletion from ledger_transactions table").Default("false").Bool()
	skipLedgersTable            = app.Flag("skip-ledgers", "Whether to skip deletion from ledgers table").Default("false").Bool()
	skipWriteLatestLedger       = app.Flag("skip-write-latest-ledger", "Whether to skip writing the latest ledger index").Default("false").Bool()

	workerCount = 1                // the calculated number of parallel goroutines the client should run
	ranges      []*util.TokenRange // the calculated ranges to be executed in parallel
)

func main() {
	log.SetOutput(os.Stdout)

	command := kingpin.MustParse(app.Parse(os.Args[1:]))
	cluster, err := prepareDb(hosts)
	if err != nil {
		log.Fatal(err)
	}

	clioCass := cass.NewClioCass(&cass.Settings{
		SkipSuccessorTable:          *skipSuccessorTable,
		SkipObjectsTable:            *skipObjectsTable,
		SkipLedgerHashesTable:       *skipLedgerHashesTable,
		SkipTransactionsTable:       *skipTransactionsTable,
		SkipDiffTable:               *skipDiffTable,
		SkipLedgerTransactionsTable: *skipLedgerHashesTable,
		SkipLedgersTable:            *skipLedgersTable,
		SkipWriteLatestLedger:       *skipWriteLatestLedger,
		WorkerCount:                 workerCount,
		Ranges:                      ranges}, cluster)

	switch command {
	case deleteAfter.FullCommand():
		if *deleteAfterLedgerIdx == 0 {
			log.Println("Please specify ledger index to delete from")
			return
		}

		displayParams("delete-after", hosts, cluster.Timeout/1000/1000, *deleteAfterLedgerIdx)
		log.Printf("Will delete everything after ledger index %d (exclusive) and till latest\n", *deleteAfterLedgerIdx)
		log.Println("WARNING: Please make sure that there are no Clio writers operating on the DB while this script is running")

		if !util.PromptContinue() {
			log.Fatal("Aborted")
		}

		startTime := time.Now().UTC()
		clioCass.DeleteAfter(*deleteAfterLedgerIdx)

		fmt.Printf("Total Execution Time: %s\n\n", time.Since(startTime))
		fmt.Println("NOTE: Cassandra/ScyllaDB only writes tombstones. You need to run compaction to free up disk space.")

	case deleteBefore.FullCommand():
		if *deleteBeforeLedgerIdx == 0 {
			log.Println("Please specify ledger index to delete until")
			return
		}

		displayParams("delete-before", hosts, cluster.Timeout/1000/1000, *deleteBeforeLedgerIdx)
		log.Printf("Will delete everything before ledger index %d (exclusive)\n", *deleteBeforeLedgerIdx)
		log.Println("WARNING: Please make sure that there are no Clio writers operating on the DB while this script is running")

		if !util.PromptContinue() {
			log.Fatal("Aborted")
		}

		startTime := time.Now().UTC()
		clioCass.DeleteBefore(*deleteBeforeLedgerIdx)

		fmt.Printf("Total Execution Time: %s\n\n", time.Since(startTime))
		fmt.Println("NOTE: Cassandra/ScyllaDB only writes tombstones. You need to run compaction to free up disk space.")
	case getLedgerRange.FullCommand():
		from, to, err := clioCass.GetLedgerRange()
		if err != nil {
			log.Fatal(err)
		}

		fmt.Printf("Range: %d -> %d\n", from, to)
	}
}

func displayParams(command string, hosts *string, timeout time.Duration, ledgerIdx uint64) {
	runParameters := fmt.Sprintf(`
Execution Parameters:
=====================

Command                       : %s
Ledger index                  : %d
Scylla cluster nodes          : %s
Keyspace                      : %s
Consistency                   : %s
Timeout (ms)                  : %d
Connections per host          : %d
CQL Version                   : %s
Page size                     : %d
# of parallel threads         : %d
# of ranges to be executed    : %d

Skip deletion of:
- successor table             : %t
- objects table               : %t
- ledger_hashes table         : %t
- transactions table          : %t
- diff table                  : %t
- ledger_transactions table   : %t
- ledgers table               : %t

Will update ledger_range      : %t

`,
		command,
		ledgerIdx,
		*hosts,
		*keyspace,
		*clusterConsistency,
		timeout,
		*clusterNumConnections,
		*clusterCQLVersion,
		*clusterPageSize,
		workerCount,
		len(ranges),
		*skipSuccessorTable || command == "delete-before",
		*skipObjectsTable,
		*skipLedgerHashesTable,
		*skipTransactionsTable,
		*skipDiffTable,
		*skipLedgerTransactionsTable,
		*skipLedgersTable,
		!*skipWriteLatestLedger)

	fmt.Println(runParameters)
}

func prepareDb(dbHosts *string) (*gocql.ClusterConfig, error) {
	workerCount = (*nodesInCluster) * (*coresInNode) * (*smudgeFactor)
	ranges = util.GetTokenRanges(workerCount)
	util.Shuffle(ranges)

	hosts := strings.Split(*dbHosts, ",")

	cluster := gocql.NewCluster(hosts...)
	cluster.Consistency = util.GetConsistencyLevel(*clusterConsistency)
	cluster.Timeout = time.Duration(*clusterTimeout * 1000 * 1000)
	cluster.NumConns = *clusterNumConnections
	cluster.CQLVersion = *clusterCQLVersion
	cluster.PageSize = *clusterPageSize
	cluster.Keyspace = *keyspace

	if *userName != "" {
		cluster.Authenticator = gocql.PasswordAuthenticator{
			Username: *userName,
			Password: *password,
		}
	}

	return cluster, nil
}
@@ -3,9 +3,13 @@ module xrplf/clio/cassandra_delete_range
go 1.21.6

require (
	github.com/alecthomas/kingpin/v2 v2.4.0 // indirect
	github.com/alecthomas/kingpin/v2 v2.4.0
	github.com/gocql/gocql v1.6.0
	github.com/pmorelli92/maybe v1.1.0
)

require (
	github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 // indirect
	github.com/gocql/gocql v1.6.0 // indirect
	github.com/golang/snappy v0.0.3 // indirect
	github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect
	github.com/xhit/go-str2duration/v2 v2.1.0 // indirect
 
@@ -2,25 +2,38 @@ github.com/alecthomas/kingpin/v2 v2.4.0 h1:f48lwail6p8zpO1bC4TxtqACaGqHYA22qkHjH
github.com/alecthomas/kingpin/v2 v2.4.0/go.mod h1:0gyi0zQnjuFk8xrkNKamJoyUo382HRL7ATRpFZCw6tE=
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137 h1:s6gZFSlWYmbqAuRjVTiNNhvNRfY2Wxp9nhfyel4rklc=
github.com/alecthomas/units v0.0.0-20211218093645-b94a6e3cc137/go.mod h1:OMCwj8VM1Kc9e19TLln2VL61YJF0x1XFtfdL4JdbSyE=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY=
github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/gocql/gocql v1.6.0 h1:IdFdOTbnpbd0pDhl4REKQDM+Q0SzKXQ1Yh+YZZ8T/qU=
github.com/gocql/gocql v1.6.0/go.mod h1:3gM2c4D3AnkISwBxGnMMsS8Oy4y2lhbPRsH4xnJrHG8=
github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA=
github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/pmorelli92/maybe v1.1.0 h1:uyV6NLF4453AQARZ6rKpJNzc9PBsQmpGDtUonhxInPU=
github.com/pmorelli92/maybe v1.1.0/go.mod h1:5PrW2+fo4/j/LMX6HT49Hb3/HOKv1tbodkzgy4lEopA=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.8.2 h1:+h33VjcLVPDHtOdpUCuF+7gSuG3yGIftsP1YvFihtJ8=
github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/xhit/go-str2duration/v2 v2.1.0 h1:lxklc02Drh6ynqX+DdPyp5pCKLUQpRT8bp8Ydu2Bstc=
github.com/xhit/go-str2duration/v2 v2.1.0/go.mod h1:ohY8p+0f07DiV6Em5LKB0s2YpLtXVyJfNt1+BlmyAsU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
 
564 tools/cassandra_delete_range/internal/cass/cass.go Normal file
@@ -0,0 +1,564 @@
 | 
			
		||||
package cass
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
	"slices"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"sync/atomic"
 | 
			
		||||
	"xrplf/clio/cassandra_delete_range/internal/util"
 | 
			
		||||
 | 
			
		||||
	"github.com/gocql/gocql"
 | 
			
		||||
	"github.com/pmorelli92/maybe"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type deleteInfo struct {
 | 
			
		||||
	Query string
 | 
			
		||||
	Data  []deleteParams
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type deleteParams struct {
 | 
			
		||||
	Seq  uint64
 | 
			
		||||
	Blob []byte // hash, key, etc
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type columnSettings struct {
 | 
			
		||||
	UseSeq  bool
 | 
			
		||||
	UseBlob bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Settings struct {
 | 
			
		||||
	SkipSuccessorTable          bool
 | 
			
		||||
	SkipObjectsTable            bool
 | 
			
		||||
	SkipLedgerHashesTable       bool
 | 
			
		||||
	SkipTransactionsTable       bool
 | 
			
		||||
	SkipDiffTable               bool
 | 
			
		||||
	SkipLedgerTransactionsTable bool
 | 
			
		||||
	SkipLedgersTable            bool
 | 
			
		||||
	SkipWriteLatestLedger       bool
 | 
			
		||||
 | 
			
		||||
	WorkerCount int
 | 
			
		||||
	Ranges      []*util.TokenRange
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Cass interface {
 | 
			
		||||
	GetLedgerRange() (uint64, uint64, error)
 | 
			
		||||
	DeleteBefore(ledgerIdx uint64)
 | 
			
		||||
	DeleteAfter(ledgerIdx uint64)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ClioCass struct {
 | 
			
		||||
	settings      *Settings
 | 
			
		||||
	clusterConfig *gocql.ClusterConfig
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewClioCass(settings *Settings, cluster *gocql.ClusterConfig) *ClioCass {
 | 
			
		||||
	return &ClioCass{settings, cluster}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ClioCass) DeleteBefore(ledgerIdx uint64) {
 | 
			
		||||
	firstLedgerIdxInDB, latestLedgerIdxInDB, err := c.GetLedgerRange()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Printf("DB ledger range is %d -> %d\n", firstLedgerIdxInDB, latestLedgerIdxInDB)
 | 
			
		||||
 | 
			
		||||
	if firstLedgerIdxInDB > ledgerIdx {
 | 
			
		||||
		log.Fatal("Earliest ledger index in DB is greater than the one specified. Aborting...")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if latestLedgerIdxInDB < ledgerIdx {
 | 
			
		||||
		log.Fatal("Latest ledger index in DB is smaller than the one specified. Aborting...")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		from maybe.Maybe[uint64] // not used
 | 
			
		||||
		to   maybe.Maybe[uint64] = maybe.Set(ledgerIdx - 1)
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	c.settings.SkipSuccessorTable = true // skip successor update until we know how to do it
 | 
			
		||||
	if err := c.pruneData(from, to, firstLedgerIdxInDB, latestLedgerIdxInDB); err != nil {
 | 
			
		||||
		log.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ClioCass) DeleteAfter(ledgerIdx uint64) {
 | 
			
		||||
	firstLedgerIdxInDB, latestLedgerIdxInDB, err := c.GetLedgerRange()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	log.Printf("DB ledger range is %d -> %d\n", firstLedgerIdxInDB, latestLedgerIdxInDB)
 | 
			
		||||
 | 
			
		||||
	if firstLedgerIdxInDB > ledgerIdx {
 | 
			
		||||
		log.Fatal("Earliest ledger index in DB is greater than the one specified. Aborting...")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if latestLedgerIdxInDB < ledgerIdx {
 | 
			
		||||
		log.Fatal("Latest ledger index in DB is smaller than the one specified. Aborting...")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var (
 | 
			
		||||
		from maybe.Maybe[uint64] = maybe.Set(ledgerIdx + 1)
 | 
			
		||||
		to   maybe.Maybe[uint64] // not used
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	if err := c.pruneData(from, to, firstLedgerIdxInDB, latestLedgerIdxInDB); err != nil {
 | 
			
		||||
		log.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ClioCass) GetLedgerRange() (uint64, uint64, error) {
 | 
			
		||||
	var (
 | 
			
		||||
		firstLedgerIdx  uint64
 | 
			
		||||
		latestLedgerIdx uint64
 | 
			
		||||
	)
 | 
			
		||||
 | 
			
		||||
	session, err := c.clusterConfig.CreateSession()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		log.Fatal(err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer session.Close()
 | 
			
		||||
 | 
			
		||||
	if err := session.Query("SELECT sequence FROM ledger_range WHERE is_latest = ?", false).Scan(&firstLedgerIdx); err != nil {
 | 
			
		||||
		return 0, 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := session.Query("SELECT sequence FROM ledger_range WHERE is_latest = ?", true).Scan(&latestLedgerIdx); err != nil {
 | 
			
		||||
		return 0, 0, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return firstLedgerIdx, latestLedgerIdx, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ClioCass) pruneData(
	fromLedgerIdx maybe.Maybe[uint64],
	toLedgerIdx maybe.Maybe[uint64],
	firstLedgerIdxInDB uint64,
	latestLedgerIdxInDB uint64,
) error {
	var totalErrors uint64
	var totalRows uint64
	var totalDeletes uint64

	var info deleteInfo
	var rowsCount uint64
	var deleteCount uint64
	var errCount uint64

	// calculate range of simple delete queries
	var (
		rangeFrom uint64 = firstLedgerIdxInDB
		rangeTo   uint64 = latestLedgerIdxInDB
	)

	if fromLedgerIdx.HasValue() {
		rangeFrom = fromLedgerIdx.Value()
	}

	if toLedgerIdx.HasValue() {
		rangeTo = toLedgerIdx.Value()
	}

	// calculate and print deletion plan
	fromStr := "beginning"
	if fromLedgerIdx.HasValue() {
		fromStr = strconv.Itoa(int(fromLedgerIdx.Value()))
	}

	toStr := "latest"
	if toLedgerIdx.HasValue() {
		toStr = strconv.Itoa(int(toLedgerIdx.Value()))
	}

	log.Printf("Start scanning and removing data for %s -> %s\n\n", fromStr, toStr)

	// successor queries
	if !c.settings.SkipSuccessorTable {
		log.Println("Generating delete queries for successor table")
		info, rowsCount, errCount = c.prepareDeleteQueries(fromLedgerIdx, toLedgerIdx,
			"SELECT key, seq FROM successor WHERE token(key) >= ? AND token(key) <= ?",
			"DELETE FROM successor WHERE key = ? AND seq = ?", false)
		log.Printf("Total delete queries: %d\n", len(info.Data))
		log.Printf("Total traversed rows: %d\n\n", rowsCount)
		totalErrors += errCount
		totalRows += rowsCount
		deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: true, UseSeq: true})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// objects queries
	if !c.settings.SkipObjectsTable {
		log.Println("Generating delete queries for objects table")
		info, rowsCount, errCount = c.prepareDeleteQueries(fromLedgerIdx, toLedgerIdx,
			"SELECT key, sequence FROM objects WHERE token(key) >= ? AND token(key) <= ?",
			"DELETE FROM objects WHERE key = ? AND sequence = ?", true)
		log.Printf("Total delete queries: %d\n", len(info.Data))
		log.Printf("Total traversed rows: %d\n\n", rowsCount)
		totalErrors += errCount
		totalRows += rowsCount
		deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: true, UseSeq: true})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// ledger_hashes queries
	if !c.settings.SkipLedgerHashesTable {
		log.Println("Generating delete queries for ledger_hashes table")
		info, rowsCount, errCount = c.prepareDeleteQueries(fromLedgerIdx, toLedgerIdx,
			"SELECT hash, sequence FROM ledger_hashes WHERE token(hash) >= ? AND token(hash) <= ?",
			"DELETE FROM ledger_hashes WHERE hash = ?", false)
		log.Printf("Total delete queries: %d\n", len(info.Data))
		log.Printf("Total traversed rows: %d\n\n", rowsCount)
		totalErrors += errCount
		totalRows += rowsCount
		deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: true, UseSeq: false})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// transactions queries
	if !c.settings.SkipTransactionsTable {
		log.Println("Generating delete queries for transactions table")
		info, rowsCount, errCount = c.prepareDeleteQueries(fromLedgerIdx, toLedgerIdx,
			"SELECT hash, ledger_sequence FROM transactions WHERE token(hash) >= ? AND token(hash) <= ?",
			"DELETE FROM transactions WHERE hash = ?", false)
		log.Printf("Total delete queries: %d\n", len(info.Data))
		log.Printf("Total traversed rows: %d\n\n", rowsCount)
		totalErrors += errCount
		totalRows += rowsCount
		deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: true, UseSeq: false})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// diff queries
	if !c.settings.SkipDiffTable {
		log.Println("Generating delete queries for diff table")
		info = c.prepareSimpleDeleteQueries(rangeFrom, rangeTo,
			"DELETE FROM diff WHERE seq = ?")
		log.Printf("Total delete queries: %d\n\n", len(info.Data))
		deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: false, UseSeq: true})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// ledger_transactions queries
	if !c.settings.SkipLedgerTransactionsTable {
		log.Println("Generating delete queries for ledger_transactions table")
		info = c.prepareSimpleDeleteQueries(rangeFrom, rangeTo,
			"DELETE FROM ledger_transactions WHERE ledger_sequence = ?")
		log.Printf("Total delete queries: %d\n\n", len(info.Data))
		deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: false, UseSeq: true})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// ledgers queries
	if !c.settings.SkipLedgersTable {
		log.Println("Generating delete queries for ledgers table")

		info = c.prepareSimpleDeleteQueries(rangeFrom, rangeTo,
			"DELETE FROM ledgers WHERE sequence = ?")
		log.Printf("Total delete queries: %d\n\n", len(info.Data))
		deleteCount, errCount = c.performDeleteQueries(&info, columnSettings{UseBlob: false, UseSeq: true})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// TODO: tbd what to do with account_tx as it got tuple for seq_idx
	// TODO: also, whether we need to take care of nft tables and other stuff like that

	if !c.settings.SkipWriteLatestLedger {
		var (
			first maybe.Maybe[uint64]
			last  maybe.Maybe[uint64]
		)

		if fromLedgerIdx.HasValue() {
			last = maybe.Set(fromLedgerIdx.Value() - 1)
		}

		if toLedgerIdx.HasValue() {
			first = maybe.Set(toLedgerIdx.Value() + 1)
		}

		if err := c.updateLedgerRange(first, last); err != nil {
			log.Printf("ERROR failed updating ledger range: %s\n", err)
			return err
		}
	}

	log.Printf("TOTAL ERRORS: %d\n", totalErrors)
	log.Printf("TOTAL ROWS TRAVERSED: %d\n", totalRows)
	log.Printf("TOTAL DELETES: %d\n\n", totalDeletes)

	log.Printf("Completed deletion for %s -> %s\n\n", fromStr, toStr)

	return nil
}

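// [added note, not in the original diff] prepareSimpleDeleteQueries emits one bound
// DELETE per ledger sequence in [fromLedgerIdx, toLedgerIdx]; it is used for the tables
// keyed purely by sequence (diff, ledger_transactions, ledgers). For example,
// rangeFrom=100 and rangeTo=102 produce three deleteParams with Seq 100, 101 and 102.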
func (c *ClioCass) prepareSimpleDeleteQueries(
	fromLedgerIdx uint64,
	toLedgerIdx uint64,
	deleteQueryTemplate string,
) deleteInfo {
	var info = deleteInfo{Query: deleteQueryTemplate}

	for i := fromLedgerIdx; i <= toLedgerIdx; i++ {
		info.Data = append(info.Data, deleteParams{Seq: i})
	}

	return info
}

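// [added note, not in the original diff] prepareDeleteQueries fans the precomputed
// token ranges out to WorkerCount goroutines; each worker pages through the SELECT
// template for its ranges and sends deleteParams for rows inside the prune window to a
// single collector goroutine, which gathers them into info.Data. The keepLastValid flag
// (used for the objects table) appears to preserve, for each key, the most recent row
// at or below the toLedgerIdx cutoff, assuming rows for a key arrive newest-first.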
func (c *ClioCass) prepareDeleteQueries(
	fromLedgerIdx maybe.Maybe[uint64],
	toLedgerIdx maybe.Maybe[uint64],
	queryTemplate string,
	deleteQueryTemplate string,
	keepLastValid bool,
) (deleteInfo, uint64, uint64) {
	rangesChannel := make(chan *util.TokenRange, len(c.settings.Ranges))
	for i := range c.settings.Ranges {
		rangesChannel <- c.settings.Ranges[i]
	}

	close(rangesChannel)

	outChannel := make(chan deleteParams)
	var info = deleteInfo{Query: deleteQueryTemplate}

	go func() {
		total := uint64(0)
		for params := range outChannel {
			total += 1
			if total%1000 == 0 {
				log.Printf("... %d queries ...\n", total)
			}
			info.Data = append(info.Data, params)
		}
	}()

	var wg sync.WaitGroup
	var sessionCreationWaitGroup sync.WaitGroup
	var totalRows uint64
	var totalErrors uint64

	wg.Add(c.settings.WorkerCount)
	sessionCreationWaitGroup.Add(c.settings.WorkerCount)

	for i := 0; i < c.settings.WorkerCount; i++ {
		go func(q string) {
			defer wg.Done()

			var session *gocql.Session
			var err error
			if session, err = c.clusterConfig.CreateSession(); err == nil {
				defer session.Close()

				sessionCreationWaitGroup.Done()
				sessionCreationWaitGroup.Wait()
				preparedQuery := session.Query(q)

				for r := range rangesChannel {
					preparedQuery.Bind(r.StartRange, r.EndRange)

					var pageState []byte
					var rowsRetrieved uint64

					var previousKey []byte
					var foundLastValid bool

					for {
						iter := preparedQuery.PageSize(c.clusterConfig.PageSize).PageState(pageState).Iter()
						nextPageState := iter.PageState()
						scanner := iter.Scanner()

						for scanner.Next() {
							var key []byte
							var seq uint64

							err = scanner.Scan(&key, &seq)
							if err == nil {
								rowsRetrieved++

								if keepLastValid && !slices.Equal(previousKey, key) {
									previousKey = key
									foundLastValid = false
								}

								// only grab the rows that are in the correct range of sequence numbers
								if fromLedgerIdx.HasValue() && fromLedgerIdx.Value() <= seq {
									outChannel <- deleteParams{Seq: seq, Blob: key}
								} else if toLedgerIdx.HasValue() {
									if seq < toLedgerIdx.Value() && (!keepLastValid || foundLastValid) {
										outChannel <- deleteParams{Seq: seq, Blob: key}
									} else if seq <= toLedgerIdx.Value()+1 {
										foundLastValid = true
									}
								}
							} else {
								log.Printf("ERROR: page iteration failed: %s\n", err)
								fmt.Fprintf(os.Stderr, "FAILED QUERY: %s\n", fmt.Sprintf("%s [from=%d][to=%d][pagestate=%x]", queryTemplate, r.StartRange, r.EndRange, pageState))
								atomic.AddUint64(&totalErrors, 1)
							}
						}

						if len(nextPageState) == 0 {
							break
						}

						pageState = nextPageState
					}

					atomic.AddUint64(&totalRows, rowsRetrieved)
				}
			} else {
				log.Printf("ERROR: %s\n", err)
				fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
				atomic.AddUint64(&totalErrors, 1)
			}
		}(queryTemplate)
	}

	wg.Wait()
	close(outChannel)

	return info, totalRows, totalErrors
}

func (c *ClioCass) splitDeleteWork(info *deleteInfo) [][]deleteParams {
	var n = c.settings.WorkerCount
	var chunkSize = len(info.Data) / n
	var chunks [][]deleteParams

	if len(info.Data) == 0 {
		return chunks
	}

	if chunkSize < 1 {
		chunks = append(chunks, info.Data)
		return chunks
	}

	for i := 0; i < len(info.Data); i += chunkSize {
		end := i + chunkSize

		if end > len(info.Data) {
			end = len(info.Data)
		}

		chunks = append(chunks, info.Data[i:end])
	}

	return chunks
}

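// [added note, not in the original diff] performDeleteQueries splits info.Data into
// roughly WorkerCount chunks via splitDeleteWork, opens one session per worker, and
// binds (Blob, Seq), Seq or Blob depending on how many "?" placeholders the delete
// template has and on columnSettings. It returns the total number of executed deletes
// and the number of errors.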
func (c *ClioCass) performDeleteQueries(info *deleteInfo, colSettings columnSettings) (uint64, uint64) {
	var wg sync.WaitGroup
	var sessionCreationWaitGroup sync.WaitGroup
	var totalDeletes uint64
	var totalErrors uint64

	chunks := c.splitDeleteWork(info)
	chunksChannel := make(chan []deleteParams, len(chunks))
	for i := range chunks {
		chunksChannel <- chunks[i]
	}

	close(chunksChannel)

	wg.Add(c.settings.WorkerCount)
	sessionCreationWaitGroup.Add(c.settings.WorkerCount)

	query := info.Query
	bindCount := strings.Count(query, "?")

	for i := 0; i < c.settings.WorkerCount; i++ {
		go func(number int, q string, bc int) {
			defer wg.Done()

			var session *gocql.Session
			var err error
			if session, err = c.clusterConfig.CreateSession(); err == nil {
				defer session.Close()

				sessionCreationWaitGroup.Done()
				sessionCreationWaitGroup.Wait()
				preparedQuery := session.Query(q)

				for chunk := range chunksChannel {
					for _, r := range chunk {
						if bc == 2 {
							preparedQuery.Bind(r.Blob, r.Seq)
						} else if bc == 1 {
							if colSettings.UseSeq {
								preparedQuery.Bind(r.Seq)
							} else if colSettings.UseBlob {
								preparedQuery.Bind(r.Blob)
							}
						}

						if err := preparedQuery.Exec(); err != nil {
							log.Printf("DELETE ERROR: %s\n", err)
							fmt.Fprintf(os.Stderr, "FAILED QUERY: %s\n", fmt.Sprintf("%s [blob=0x%x][seq=%d]", info.Query, r.Blob, r.Seq))
							atomic.AddUint64(&totalErrors, 1)
						} else {
							atomic.AddUint64(&totalDeletes, 1)
							if atomic.LoadUint64(&totalDeletes)%10000 == 0 {
								log.Printf("... %d deletes ...\n", totalDeletes)
							}
						}
					}
				}
			} else {
				log.Printf("ERROR: %s\n", err)
				fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
				atomic.AddUint64(&totalErrors, 1)
			}
		}(i, query, bindCount)
	}

	wg.Wait()
	return totalDeletes, totalErrors
}

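// [added note, not in the original diff] The ledger_range table holds two rows keyed by
// is_latest: false stores the first kept sequence and true stores the latest one (see
// GetLedgerRange above). updateLedgerRange rewrites whichever bound it is given.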
func (c *ClioCass) updateLedgerRange(newStartLedger maybe.Maybe[uint64], newEndLedger maybe.Maybe[uint64]) error {
	if session, err := c.clusterConfig.CreateSession(); err == nil {
		defer session.Close()

		query := "UPDATE ledger_range SET sequence = ? WHERE is_latest = ?"

		if newEndLedger.HasValue() {
			log.Printf("Updating ledger range end to %d\n", newEndLedger.Value())

			preparedQuery := session.Query(query, newEndLedger.Value(), true)
			if err := preparedQuery.Exec(); err != nil {
				fmt.Fprintf(os.Stderr, "FAILED QUERY: %s [seq=%d][true]\n", query, newEndLedger.Value())
				return err
			}
		}

		if newStartLedger.HasValue() {
			log.Printf("Updating ledger range start to %d\n", newStartLedger.Value())

			preparedQuery := session.Query(query, newStartLedger.Value(), false)
			if err := preparedQuery.Exec(); err != nil {
				fmt.Fprintf(os.Stderr, "FAILED QUERY: %s [seq=%d][false]\n", query, newStartLedger.Value())
				return err
			}
		}
	} else {
		fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
		return err
	}

	return nil
}

tools/cassandra_delete_range/internal/util/util.go (new file, 91 lines)
@@ -0,0 +1,91 @@
package util

import (
	"fmt"
	"log"
	"math"
	"math/rand"

	"github.com/gocql/gocql"
)

type TokenRange struct {
	StartRange int64
	EndRange   int64
}

func Shuffle(data []*TokenRange) {
	for i := 1; i < len(data); i++ {
		r := rand.Intn(i + 1)
		if i != r {
			data[r], data[i] = data[i], data[r]
		}
	}
}

func PromptContinue() bool {
	var continueFlag string

	log.Println("Are you sure you want to continue? (y/n)")
	if fmt.Scanln(&continueFlag); continueFlag != "y" {
		return false
	}

	return true
}

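// [added note, not in the original diff] GetTokenRanges splits the signed 64-bit token
// ring [math.MinInt64, math.MaxInt64] into workerCount*100 contiguous ranges so that
// full-table scans restricted by token(...) can run in parallel; e.g. workerCount=24
// yields 2400 ranges.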
func GetTokenRanges(workerCount int) []*TokenRange {
	var n = workerCount
	var m = int64(n * 100)
	var maxSize uint64 = math.MaxInt64 * 2
	var rangeSize = maxSize / uint64(m)

	var start int64 = math.MinInt64
	var end int64
	var shouldBreak = false

	var ranges = make([]*TokenRange, m)

	for i := int64(0); i < m; i++ {
		end = start + int64(rangeSize)
		if start > 0 && end < 0 {
			end = math.MaxInt64
			shouldBreak = true
		}

		ranges[i] = &TokenRange{StartRange: start, EndRange: end}

		if shouldBreak {
			break
		}

		start = end + 1
	}

	return ranges
}

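// [added note, not in the original diff] GetConsistencyLevel maps a consistency level
// name such as "localquorum" to a gocql.Consistency, falling back to gocql.One for
// unrecognized input.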
func GetConsistencyLevel(consistencyValue string) gocql.Consistency {
	switch consistencyValue {
	case "any":
		return gocql.Any
	case "one":
		return gocql.One
	case "two":
		return gocql.Two
	case "three":
		return gocql.Three
	case "quorum":
		return gocql.Quorum
	case "all":
		return gocql.All
	case "localquorum":
		return gocql.LocalQuorum
	case "eachquorum":
		return gocql.EachQuorum
	case "localone":
		return gocql.LocalOne
	default:
		return gocql.One
	}
}

@@ -1,619 +0,0 @@
//
// Based off of https://github.com/scylladb/scylla-code-samples/blob/master/efficient_full_table_scan_example_code/efficient_full_table_scan.go
//

package main

import (
	"fmt"
	"log"
	"math"
	"math/rand"
	"os"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/alecthomas/kingpin/v2"
	"github.com/gocql/gocql"
)

const (
	defaultNumberOfNodesInCluster = 3
	defaultNumberOfCoresInNode    = 8
	defaultSmudgeFactor           = 3
)

var (
	clusterHosts      = kingpin.Arg("hosts", "Your Scylla nodes IP addresses, comma separated (i.e. 192.168.1.1,192.168.1.2,192.168.1.3)").Required().String()
	earliestLedgerIdx = kingpin.Flag("ledgerIdx", "Sets the earliest ledger_index to keep untouched").Short('i').Required().Uint64()

	nodesInCluster        = kingpin.Flag("nodes-in-cluster", "Number of nodes in your Scylla cluster").Short('n').Default(fmt.Sprintf("%d", defaultNumberOfNodesInCluster)).Int()
	coresInNode           = kingpin.Flag("cores-in-node", "Number of cores in each node").Short('c').Default(fmt.Sprintf("%d", defaultNumberOfCoresInNode)).Int()
	smudgeFactor          = kingpin.Flag("smudge-factor", "Yet another factor to make parallelism cooler").Short('s').Default(fmt.Sprintf("%d", defaultSmudgeFactor)).Int()
	clusterConsistency    = kingpin.Flag("consistency", "Cluster consistency level. Use 'localone' for multi DC").Short('o').Default("localquorum").String()
	clusterTimeout        = kingpin.Flag("timeout", "Maximum duration for query execution in millisecond").Short('t').Default("15000").Int()
	clusterNumConnections = kingpin.Flag("cluster-number-of-connections", "Number of connections per host per session (in our case, per thread)").Short('b').Default("1").Int()
	clusterCQLVersion     = kingpin.Flag("cql-version", "The CQL version to use").Short('l').Default("3.0.0").String()
	clusterPageSize       = kingpin.Flag("cluster-page-size", "Page size of results").Short('p').Default("5000").Int()
	keyspace              = kingpin.Flag("keyspace", "Keyspace to use").Short('k').Default("clio_fh").String()

	userName = kingpin.Flag("username", "Username to use when connecting to the cluster").String()
	password = kingpin.Flag("password", "Password to use when connecting to the cluster").String()

	skipSuccessorTable          = kingpin.Flag("skip-successor", "Whether to skip deletion from successor table").Default("false").Bool()
	skipObjectsTable            = kingpin.Flag("skip-objects", "Whether to skip deletion from objects table").Default("false").Bool()
	skipLedgerHashesTable       = kingpin.Flag("skip-ledger-hashes", "Whether to skip deletion from ledger_hashes table").Default("false").Bool()
	skipTransactionsTable       = kingpin.Flag("skip-transactions", "Whether to skip deletion from transactions table").Default("false").Bool()
	skipDiffTable               = kingpin.Flag("skip-diff", "Whether to skip deletion from diff table").Default("false").Bool()
	skipLedgerTransactionsTable = kingpin.Flag("skip-ledger-transactions", "Whether to skip deletion from ledger_transactions table").Default("false").Bool()
	skipLedgersTable            = kingpin.Flag("skip-ledgers", "Whether to skip deletion from ledgers table").Default("false").Bool()
	skipWriteLatestLedger       = kingpin.Flag("skip-write-latest-ledger", "Whether to skip writing the latest ledger index").Default("false").Bool()

	workerCount = 1           // the calculated number of parallel goroutines the client should run
	ranges      []*tokenRange // the calculated ranges to be executed in parallel
)

type tokenRange struct {
	StartRange int64
	EndRange   int64
}

type deleteParams struct {
	Seq  uint64
	Blob []byte // hash, key, etc
}

type columnSettings struct {
	UseSeq  bool
	UseBlob bool
}

type deleteInfo struct {
	Query string
	Data  []deleteParams
}

func getTokenRanges() []*tokenRange {
	var n = workerCount
	var m = int64(n * 100)
	var maxSize uint64 = math.MaxInt64 * 2
	var rangeSize = maxSize / uint64(m)

	var start int64 = math.MinInt64
	var end int64
	var shouldBreak = false

	var ranges = make([]*tokenRange, m)

	for i := int64(0); i < m; i++ {
		end = start + int64(rangeSize)
		if start > 0 && end < 0 {
			end = math.MaxInt64
			shouldBreak = true
		}

		ranges[i] = &tokenRange{StartRange: start, EndRange: end}

		if shouldBreak {
			break
		}

		start = end + 1
	}

	return ranges
}

func splitDeleteWork(info *deleteInfo) [][]deleteParams {
	var n = workerCount
	var chunkSize = len(info.Data) / n
	var chunks [][]deleteParams

	if len(info.Data) == 0 {
		return chunks
	}

	if chunkSize < 1 {
		chunks = append(chunks, info.Data)
		return chunks
	}

	for i := 0; i < len(info.Data); i += chunkSize {
		end := i + chunkSize

		if end > len(info.Data) {
			end = len(info.Data)
		}

		chunks = append(chunks, info.Data[i:end])
	}

	return chunks
}

func shuffle(data []*tokenRange) {
	for i := 1; i < len(data); i++ {
		r := rand.Intn(i + 1)
		if i != r {
			data[r], data[i] = data[i], data[r]
		}
	}
}

func getConsistencyLevel(consistencyValue string) gocql.Consistency {
	switch consistencyValue {
	case "any":
		return gocql.Any
	case "one":
		return gocql.One
	case "two":
		return gocql.Two
	case "three":
		return gocql.Three
	case "quorum":
		return gocql.Quorum
	case "all":
		return gocql.All
	case "localquorum":
		return gocql.LocalQuorum
	case "eachquorum":
		return gocql.EachQuorum
	case "localone":
		return gocql.LocalOne
	default:
		return gocql.One
	}
}

func main() {
	log.SetOutput(os.Stdout)
	kingpin.Parse()

	workerCount = (*nodesInCluster) * (*coresInNode) * (*smudgeFactor)
	ranges = getTokenRanges()
	shuffle(ranges)

	hosts := strings.Split(*clusterHosts, ",")

	cluster := gocql.NewCluster(hosts...)
	cluster.Consistency = getConsistencyLevel(*clusterConsistency)
	cluster.Timeout = time.Duration(*clusterTimeout * 1000 * 1000)
	cluster.NumConns = *clusterNumConnections
	cluster.CQLVersion = *clusterCQLVersion
	cluster.PageSize = *clusterPageSize
	cluster.Keyspace = *keyspace

	if *userName != "" {
		cluster.Authenticator = gocql.PasswordAuthenticator{
			Username: *userName,
			Password: *password,
		}
	}

	if *earliestLedgerIdx == 0 {
		log.Println("Please specify ledger index to delete from")
		return
	}

	runParameters := fmt.Sprintf(`
Execution Parameters:
=====================

Range to be deleted           : %d -> latest
Scylla cluster nodes          : %s
Keyspace                      : %s
Consistency                   : %s
Timeout (ms)                  : %d
Connections per host          : %d
CQL Version                   : %s
Page size                     : %d
# of parallel threads         : %d
# of ranges to be executed    : %d

Skip deletion of:
- successor table             : %t
- objects table               : %t
- ledger_hashes table         : %t
- transactions table          : %t
- diff table                  : %t
- ledger_transactions table   : %t
- ledgers table               : %t

Will write latest ledger      : %t

`,
		*earliestLedgerIdx,
		*clusterHosts,
		*keyspace,
		*clusterConsistency,
		cluster.Timeout/1000/1000,
		*clusterNumConnections,
		*clusterCQLVersion,
		*clusterPageSize,
		workerCount,
		len(ranges),
		*skipSuccessorTable,
		*skipObjectsTable,
		*skipLedgerHashesTable,
		*skipTransactionsTable,
		*skipDiffTable,
		*skipLedgerTransactionsTable,
		*skipLedgersTable,
		!*skipWriteLatestLedger)

	fmt.Println(runParameters)

	log.Printf("Will delete everything after ledger index %d (exclusive) and till latest\n", *earliestLedgerIdx)
	log.Println("WARNING: Please make sure that there are no Clio writers operating on the DB while this script is running")
	log.Println("Are you sure you want to continue? (y/n)")

	var continueFlag string
	if fmt.Scanln(&continueFlag); continueFlag != "y" {
		log.Println("Aborting...")
		return
	}

	startTime := time.Now().UTC()

	earliestLedgerIdxInDB, latestLedgerIdxInDB, err := getLedgerRange(cluster)
	if err != nil {
		log.Fatal(err)
	}

	if earliestLedgerIdxInDB > *earliestLedgerIdx {
		log.Fatal("Earliest ledger index in DB is greater than the one specified. Aborting...")
	}

	if latestLedgerIdxInDB < *earliestLedgerIdx {
		log.Fatal("Latest ledger index in DB is smaller than the one specified. Aborting...")
	}

	if err := deleteLedgerData(cluster, *earliestLedgerIdx+1, latestLedgerIdxInDB); err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Total Execution Time: %s\n\n", time.Since(startTime))
	fmt.Println("NOTE: Cassandra/ScyllaDB only writes tombstones. You need to run compaction to free up disk space.")
}

func getLedgerRange(cluster *gocql.ClusterConfig) (uint64, uint64, error) {
	var (
		firstLedgerIdx  uint64
		latestLedgerIdx uint64
	)

	session, err := cluster.CreateSession()
	if err != nil {
		log.Fatal(err)
	}

	defer session.Close()

	if err := session.Query("select sequence from ledger_range where is_latest = ?", false).Scan(&firstLedgerIdx); err != nil {
		return 0, 0, err
	}

	if err := session.Query("select sequence from ledger_range where is_latest = ?", true).Scan(&latestLedgerIdx); err != nil {
		return 0, 0, err
	}

	log.Printf("DB ledger range is %d:%d\n", firstLedgerIdx, latestLedgerIdx)
	return firstLedgerIdx, latestLedgerIdx, nil
}

func deleteLedgerData(cluster *gocql.ClusterConfig, fromLedgerIdx uint64, toLedgerIdx uint64) error {
	var totalErrors uint64
	var totalRows uint64
	var totalDeletes uint64

	var info deleteInfo
	var rowsCount uint64
	var deleteCount uint64
	var errCount uint64

	log.Printf("Start scanning and removing data for %d -> latest (%d according to ledger_range table)\n\n", fromLedgerIdx, toLedgerIdx)

	// successor queries
	if !*skipSuccessorTable {
		log.Println("Generating delete queries for successor table")
		info, rowsCount, errCount = prepareDeleteQueries(cluster, fromLedgerIdx,
			"SELECT key, seq FROM successor WHERE token(key) >= ? AND token(key) <= ?",
			"DELETE FROM successor WHERE key = ? AND seq = ?")
		log.Printf("Total delete queries: %d\n", len(info.Data))
		log.Printf("Total traversed rows: %d\n\n", rowsCount)
		totalErrors += errCount
		totalRows += rowsCount
		deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: true, UseSeq: true})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// objects queries
	if !*skipObjectsTable {
		log.Println("Generating delete queries for objects table")
		info, rowsCount, errCount = prepareDeleteQueries(cluster, fromLedgerIdx,
			"SELECT key, sequence FROM objects WHERE token(key) >= ? AND token(key) <= ?",
			"DELETE FROM objects WHERE key = ? AND sequence = ?")
		log.Printf("Total delete queries: %d\n", len(info.Data))
		log.Printf("Total traversed rows: %d\n\n", rowsCount)
		totalErrors += errCount
		totalRows += rowsCount
		deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: true, UseSeq: true})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// ledger_hashes queries
	if !*skipLedgerHashesTable {
		log.Println("Generating delete queries for ledger_hashes table")
		info, rowsCount, errCount = prepareDeleteQueries(cluster, fromLedgerIdx,
			"SELECT hash, sequence FROM ledger_hashes WHERE token(hash) >= ? AND token(hash) <= ?",
			"DELETE FROM ledger_hashes WHERE hash = ?")
		log.Printf("Total delete queries: %d\n", len(info.Data))
		log.Printf("Total traversed rows: %d\n\n", rowsCount)
		totalErrors += errCount
		totalRows += rowsCount
		deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: true, UseSeq: false})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// transactions queries
	if !*skipTransactionsTable {
		log.Println("Generating delete queries for transactions table")
		info, rowsCount, errCount = prepareDeleteQueries(cluster, fromLedgerIdx,
			"SELECT hash, ledger_sequence FROM transactions WHERE token(hash) >= ? AND token(hash) <= ?",
			"DELETE FROM transactions WHERE hash = ?")
		log.Printf("Total delete queries: %d\n", len(info.Data))
		log.Printf("Total traversed rows: %d\n\n", rowsCount)
		totalErrors += errCount
		totalRows += rowsCount
		deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: true, UseSeq: false})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// diff queries
	if !*skipDiffTable {
		log.Println("Generating delete queries for diff table")
		info = prepareSimpleDeleteQueries(fromLedgerIdx, toLedgerIdx,
			"DELETE FROM diff WHERE seq = ?")
		log.Printf("Total delete queries: %d\n\n", len(info.Data))
		deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: true, UseSeq: true})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// ledger_transactions queries
	if !*skipLedgerTransactionsTable {
		log.Println("Generating delete queries for ledger_transactions table")
		info = prepareSimpleDeleteQueries(fromLedgerIdx, toLedgerIdx,
			"DELETE FROM ledger_transactions WHERE ledger_sequence = ?")
		log.Printf("Total delete queries: %d\n\n", len(info.Data))
		deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: false, UseSeq: true})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// ledgers queries
	if !*skipLedgersTable {
		log.Println("Generating delete queries for ledgers table")
		info = prepareSimpleDeleteQueries(fromLedgerIdx, toLedgerIdx,
			"DELETE FROM ledgers WHERE sequence = ?")
		log.Printf("Total delete queries: %d\n\n", len(info.Data))
		deleteCount, errCount = performDeleteQueries(cluster, &info, columnSettings{UseBlob: false, UseSeq: true})
		totalErrors += errCount
		totalDeletes += deleteCount
	}

	// TODO: tbd what to do with account_tx as it got tuple for seq_idx
	// TODO: also, whether we need to take care of nft tables and other stuff like that

	if !*skipWriteLatestLedger {
		if err := updateLedgerRange(cluster, fromLedgerIdx-1); err != nil {
			log.Printf("ERROR failed updating ledger range: %s\n", err)
			return err
		}

		log.Printf("Updated latest ledger to %d in ledger_range table\n\n", fromLedgerIdx-1)
	}

	log.Printf("TOTAL ERRORS: %d\n", totalErrors)
	log.Printf("TOTAL ROWS TRAVERSED: %d\n", totalRows)
	log.Printf("TOTAL DELETES: %d\n\n", totalDeletes)

	log.Printf("Completed deletion for %d -> %d\n\n", fromLedgerIdx, toLedgerIdx)

	return nil
}

func prepareSimpleDeleteQueries(fromLedgerIdx uint64, toLedgerIdx uint64, deleteQueryTemplate string) deleteInfo {
	var info = deleteInfo{Query: deleteQueryTemplate}

	// Note: we deliberately add 1 extra ledger to make sure we delete any data Clio might have written
	// if it crashed or was stopped in the middle of writing just before it wrote ledger_range.
	for i := fromLedgerIdx; i <= toLedgerIdx+1; i++ {
		info.Data = append(info.Data, deleteParams{Seq: i})
	}

	return info
}

func prepareDeleteQueries(cluster *gocql.ClusterConfig, fromLedgerIdx uint64, queryTemplate string, deleteQueryTemplate string) (deleteInfo, uint64, uint64) {
	rangesChannel := make(chan *tokenRange, len(ranges))
	for i := range ranges {
		rangesChannel <- ranges[i]
	}

	close(rangesChannel)

	outChannel := make(chan deleteParams)
	var info = deleteInfo{Query: deleteQueryTemplate}

	go func() {
		for params := range outChannel {
			info.Data = append(info.Data, params)
		}
	}()

	var wg sync.WaitGroup
	var sessionCreationWaitGroup sync.WaitGroup
	var totalRows uint64
	var totalErrors uint64

	wg.Add(workerCount)
	sessionCreationWaitGroup.Add(workerCount)

	for i := 0; i < workerCount; i++ {
		go func(q string) {
			defer wg.Done()

			var session *gocql.Session
			var err error
			if session, err = cluster.CreateSession(); err == nil {
				defer session.Close()

				sessionCreationWaitGroup.Done()
				sessionCreationWaitGroup.Wait()
				preparedQuery := session.Query(q)

				for r := range rangesChannel {
					preparedQuery.Bind(r.StartRange, r.EndRange)

					var pageState []byte
					var rowsRetrieved uint64

					for {
						iter := preparedQuery.PageSize(*clusterPageSize).PageState(pageState).Iter()
						nextPageState := iter.PageState()
						scanner := iter.Scanner()

						for scanner.Next() {
							var key []byte
							var seq uint64

							err = scanner.Scan(&key, &seq)
							if err == nil {
								rowsRetrieved++

								// only grab the rows that are in the correct range of sequence numbers
								if fromLedgerIdx <= seq {
									outChannel <- deleteParams{Seq: seq, Blob: key}
								}
							} else {
								log.Printf("ERROR: page iteration failed: %s\n", err)
								fmt.Fprintf(os.Stderr, "FAILED QUERY: %s\n", fmt.Sprintf("%s [from=%d][to=%d][pagestate=%x]", queryTemplate, r.StartRange, r.EndRange, pageState))
								atomic.AddUint64(&totalErrors, 1)
							}
						}

						if len(nextPageState) == 0 {
							break
						}

						pageState = nextPageState
					}

					atomic.AddUint64(&totalRows, rowsRetrieved)
				}
			} else {
				log.Printf("ERROR: %s\n", err)
				fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
				atomic.AddUint64(&totalErrors, 1)
			}
		}(queryTemplate)
	}

	wg.Wait()
	close(outChannel)

	return info, totalRows, totalErrors
}

func performDeleteQueries(cluster *gocql.ClusterConfig, info *deleteInfo, colSettings columnSettings) (uint64, uint64) {
	var wg sync.WaitGroup
	var sessionCreationWaitGroup sync.WaitGroup
	var totalDeletes uint64
	var totalErrors uint64

	chunks := splitDeleteWork(info)
	chunksChannel := make(chan []deleteParams, len(chunks))
	for i := range chunks {
		chunksChannel <- chunks[i]
	}

	close(chunksChannel)

	wg.Add(workerCount)
	sessionCreationWaitGroup.Add(workerCount)

	query := info.Query
	bindCount := strings.Count(query, "?")

	for i := 0; i < workerCount; i++ {
		go func(number int, q string, bc int) {
			defer wg.Done()

			var session *gocql.Session
			var err error
			if session, err = cluster.CreateSession(); err == nil {
				defer session.Close()

				sessionCreationWaitGroup.Done()
				sessionCreationWaitGroup.Wait()
				preparedQuery := session.Query(q)

				for chunk := range chunksChannel {
					for _, r := range chunk {
						if bc == 2 {
							preparedQuery.Bind(r.Blob, r.Seq)
						} else if bc == 1 {
							if colSettings.UseSeq {
								preparedQuery.Bind(r.Seq)
							} else if colSettings.UseBlob {
								preparedQuery.Bind(r.Blob)
							}
						}

						if err := preparedQuery.Exec(); err != nil {
							log.Printf("DELETE ERROR: %s\n", err)
							fmt.Fprintf(os.Stderr, "FAILED QUERY: %s\n", fmt.Sprintf("%s [blob=0x%x][seq=%d]", info.Query, r.Blob, r.Seq))
							atomic.AddUint64(&totalErrors, 1)
						} else {
							atomic.AddUint64(&totalDeletes, 1)
						}
					}
				}
			} else {
				log.Printf("ERROR: %s\n", err)
				fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
				atomic.AddUint64(&totalErrors, 1)
			}
		}(i, query, bindCount)
	}

	wg.Wait()
	return totalDeletes, totalErrors
}

func updateLedgerRange(cluster *gocql.ClusterConfig, ledgerIndex uint64) error {
	log.Printf("Updating latest ledger to %d\n", ledgerIndex)

	if session, err := cluster.CreateSession(); err == nil {
		defer session.Close()

		query := "UPDATE ledger_range SET sequence = ? WHERE is_latest = ?"
		preparedQuery := session.Query(query, ledgerIndex, true)
		if err := preparedQuery.Exec(); err != nil {
			fmt.Fprintf(os.Stderr, "FAILED QUERY: %s [seq=%d][true]\n", query, ledgerIndex)
			return err
		}
	} else {
		fmt.Fprintf(os.Stderr, "FAILED TO CREATE SESSION: %s\n", err)
		return err
	}

	return nil
}