diff --git a/CMake/deps/gtest.cmake b/CMake/deps/gtest.cmake index 73d4cbe5..b5760740 100644 --- a/CMake/deps/gtest.cmake +++ b/CMake/deps/gtest.cmake @@ -10,7 +10,7 @@ if(NOT googletest_POPULATED) add_subdirectory(${googletest_SOURCE_DIR} ${googletest_BINARY_DIR} EXCLUDE_FROM_ALL) endif() -target_link_libraries(clio_tests PUBLIC clio gtest_main) +target_link_libraries(clio_tests PUBLIC clio gmock_main) target_include_directories(clio_tests PRIVATE unittests) enable_testing() diff --git a/CMakeLists.txt b/CMakeLists.txt index 2b577899..62f1c599 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -113,7 +113,8 @@ if(BUILD_TESTS) unittests/RPCErrors.cpp unittests/Backend.cpp unittests/Logger.cpp - unittests/Config.cpp) + unittests/Config.cpp + unittests/DOSGuard.cpp) include(CMake/deps/gtest.cmake) endif() diff --git a/src/config/Config.h b/src/config/Config.h index 392d83c2..06d51ad4 100644 --- a/src/config/Config.h +++ b/src/config/Config.h @@ -362,7 +362,7 @@ private: } else if constexpr (std::is_same_v) { - if (not value.is_double()) + if (not value.is_number()) has_error = true; } else if constexpr ( diff --git a/src/main/main.cpp b/src/main/main.cpp index 7755a29c..051418c5 100644 --- a/src/main/main.cpp +++ b/src/main/main.cpp @@ -201,20 +201,19 @@ try boost::asio::io_context ioc{threads}; // Rate limiter, to prevent abuse - DOSGuard dosGuard{config, ioc}; + auto sweepHandler = IntervalSweepHandler{config, ioc}; + auto dosGuard = DOSGuard{config, sweepHandler}; // Interface to the database - std::shared_ptr backend{ - Backend::make_Backend(ioc, config)}; + auto backend = Backend::make_Backend(ioc, config); // Manages clients subscribed to streams - std::shared_ptr subscriptions{ - SubscriptionManager::make_SubscriptionManager(config, backend)}; + auto subscriptions = + SubscriptionManager::make_SubscriptionManager(config, backend); // Tracks which ledgers have been validated by the // network - std::shared_ptr ledgers{ - NetworkValidatedLedgers::make_ValidatedLedgers()}; + auto ledgers = NetworkValidatedLedgers::make_ValidatedLedgers(); // Handles the connection to one or more rippled nodes. // ETL uses the balancer to extract data. diff --git a/src/webserver/DOSGuard.h b/src/webserver/DOSGuard.h index 93d4a321..4ca533be 100644 --- a/src/webserver/DOSGuard.h +++ b/src/webserver/DOSGuard.h @@ -20,57 +20,80 @@ #pragma once #include +#include + +#include #include #include #include #include -class DOSGuard +namespace clio { + +class BaseDOSGuard { - boost::asio::io_context& ctx_; - std::mutex mtx_; // protects ipFetchCount_ +public: + virtual ~BaseDOSGuard() = default; + + virtual void + clear() noexcept = 0; +}; + +/** + * @brief A simple denial of service guard used for rate limiting. + * + * @tparam SweepHandler Type of the sweep handler + */ +template +class BasicDOSGuard : public BaseDOSGuard +{ + mutable std::mutex mtx_; // protects ipFetchCount_ std::unordered_map ipFetchCount_; std::unordered_map ipConnCount_; std::unordered_set const whitelist_; + std::uint32_t const maxFetches_; - std::uint32_t const sweepInterval_; std::uint32_t const maxConnCount_; clio::Logger log_{"RPC"}; public: - DOSGuard(clio::Config const& config, boost::asio::io_context& ctx) - : ctx_{ctx} - , whitelist_{getWhitelist(config)} + /** + * @brief Constructs a new DOS guard. + * + * @param config Clio config + * @param sweepHandler Sweep handler that implements the sweeping behaviour + */ + BasicDOSGuard(clio::Config const& config, SweepHandler& sweepHandler) + : whitelist_{getWhitelist(config)} , maxFetches_{config.valueOr("dos_guard.max_fetches", 100000000u)} - , sweepInterval_{config.valueOr("dos_guard.sweep_interval", 10u)} , maxConnCount_{config.valueOr("dos_guard.max_connections", 1u)} { - createTimer(); + sweepHandler.setup(this); } - void - createTimer() - { - auto wait = std::chrono::seconds(sweepInterval_); - std::shared_ptr timer = - std::make_shared( - ctx_, std::chrono::steady_clock::now() + wait); - timer->async_wait( - [timer, this](const boost::system::error_code& error) { - clear(); - createTimer(); - }); - } - - bool - isWhiteListed(std::string const& ip) + /** + * @brief Check whether an ip address is in the whitelist or not + * + * @param ip The ip address to check + * @return true + * @return false + */ + [[nodiscard]] bool + isWhiteListed(std::string const& ip) const noexcept { return whitelist_.contains(ip); } - bool - isOk(std::string const& ip) + /** + * @brief Check whether an ip address is currently rate limited or not + * + * @param ip The ip address to check + * @return true If not rate limited + * @return false If rate limited and the request should not be processed + */ + [[nodiscard]] bool + isOk(std::string const& ip) const noexcept { if (whitelist_.contains(ip)) return true; @@ -97,8 +120,13 @@ public: return fetchesOk && connsOk; } + /** + * @brief Increment connection count for the given ip address + * + * @param ip + */ void - increment(std::string const& ip) + increment(std::string const& ip) noexcept { if (whitelist_.contains(ip)) return; @@ -106,8 +134,13 @@ public: ipConnCount_[ip]++; } + /** + * @brief Decrement connection count for the given ip address + * + * @param ip + */ void - decrement(std::string const& ip) + decrement(std::string const& ip) noexcept { if (whitelist_.contains(ip)) return; @@ -118,8 +151,20 @@ public: ipConnCount_.erase(ip); } - bool - add(std::string const& ip, uint32_t numObjects) + /** + * @brief Adds numObjects of usage for the given ip address. + * + * If the total sums up to a value equal or larger than maxFetches_ + * the operation is no longer allowed and false is returned; true is + * returned otherwise. + * + * @param ip + * @param numObjects + * @return true + * @return false + */ + [[maybe_unused]] bool + add(std::string const& ip, uint32_t numObjects) noexcept { if (whitelist_.contains(ip)) return true; @@ -136,15 +181,19 @@ public: return isOk(ip); } + /** + * @brief Instantly clears all fetch counters added by @see add(std::string + * const&, uint32_t) + */ void - clear() + clear() noexcept override { std::unique_lock lck(mtx_); ipFetchCount_.clear(); } private: - std::unordered_set const + [[nodiscard]] std::unordered_set const getWhitelist(clio::Config const& config) const { using T = std::unordered_set const; @@ -157,3 +206,72 @@ private: boost::transform_iterator(std::end(whitelist), transform)}; } }; + +/** + * @brief Sweep handler using a steady_timer and boost::asio::io_context. + */ +class IntervalSweepHandler +{ + std::chrono::milliseconds sweepInterval_; + std::reference_wrapper ctx_; + BaseDOSGuard* dosGuard_ = nullptr; + + boost::asio::steady_timer timer_{ctx_.get()}; + +public: + /** + * @brief Construct a new interval-based sweep handler + * + * @param config Clio config + * @param ctx The boost::asio::io_context + */ + IntervalSweepHandler( + clio::Config const& config, + boost::asio::io_context& ctx) + : sweepInterval_{std::max( + 1u, + static_cast( + config.valueOr("dos_guard.sweep_interval", 10.0) * 1000.0))} + , ctx_{std::ref(ctx)} + { + } + + ~IntervalSweepHandler() + { + timer_.cancel(); + } + + /** + * @brief This setup member function is called by @ref BasicDOSGuard during + * its initialization. + * + * @param guard Pointer to the dos guard + */ + void + setup(BaseDOSGuard* guard) + { + assert(dosGuard_ == nullptr); + dosGuard_ = guard; + assert(dosGuard_ != nullptr); + + createTimer(); + } + +private: + void + createTimer() + { + timer_.expires_after(sweepInterval_); + timer_.async_wait([this](boost::system::error_code const& error) { + if (error == boost::asio::error::operation_aborted) + return; + + dosGuard_->clear(); + createTimer(); + }); + } +}; + +using DOSGuard = BasicDOSGuard; + +} // namespace clio diff --git a/src/webserver/HttpBase.h b/src/webserver/HttpBase.h index a015eec6..3e5aa59e 100644 --- a/src/webserver/HttpBase.h +++ b/src/webserver/HttpBase.h @@ -115,7 +115,7 @@ class HttpBase : public util::Taggable std::shared_ptr balancer_; std::shared_ptr etl_; util::TagDecoratorFactory const& tagFactory_; - DOSGuard& dosGuard_; + clio::DOSGuard& dosGuard_; RPC::Counters& counters_; WorkQueue& workQueue_; send_lambda lambda_; @@ -172,7 +172,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue, boost::beast::flat_buffer buffer) @@ -197,7 +197,7 @@ public: perfLog_.debug() << tag() << "http session closed"; } - DOSGuard& + clio::DOSGuard& dosGuard() { return dosGuard_; @@ -348,7 +348,7 @@ handle_request( std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, std::string const& ip, std::shared_ptr http, diff --git a/src/webserver/HttpSession.h b/src/webserver/HttpSession.h index 6459c13f..7c0eac66 100644 --- a/src/webserver/HttpSession.h +++ b/src/webserver/HttpSession.h @@ -43,7 +43,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue, boost::beast::flat_buffer buffer) diff --git a/src/webserver/Listener.h b/src/webserver/Listener.h index 706e0a75..200ce313 100644 --- a/src/webserver/Listener.h +++ b/src/webserver/Listener.h @@ -51,7 +51,7 @@ class Detector std::shared_ptr balancer_; std::shared_ptr etl_; util::TagDecoratorFactory const& tagFactory_; - DOSGuard& dosGuard_; + clio::DOSGuard& dosGuard_; RPC::Counters& counters_; WorkQueue& queue_; boost::beast::flat_buffer buffer_; @@ -66,7 +66,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue) : ioc_(ioc) @@ -164,7 +164,7 @@ make_websocket_session( std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue) { @@ -197,7 +197,7 @@ make_websocket_session( std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue) { @@ -234,7 +234,7 @@ class Listener std::shared_ptr balancer_; std::shared_ptr etl_; util::TagDecoratorFactory tagFactory_; - DOSGuard& dosGuard_; + clio::DOSGuard& dosGuard_; WorkQueue queue_; RPC::Counters counters_; @@ -250,7 +250,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory tagFactory, - DOSGuard& dosGuard) + clio::DOSGuard& dosGuard) : ioc_(ioc) , ctx_(ctx) , acceptor_(net::make_strand(ioc)) @@ -356,7 +356,7 @@ make_HttpServer( std::shared_ptr subscriptions, std::shared_ptr balancer, std::shared_ptr etl, - DOSGuard& dosGuard) + clio::DOSGuard& dosGuard) { static clio::Logger log{"WebServer"}; if (!config.contains("server")) diff --git a/src/webserver/PlainWsSession.h b/src/webserver/PlainWsSession.h index 22b6bc5d..c2b299a2 100644 --- a/src/webserver/PlainWsSession.h +++ b/src/webserver/PlainWsSession.h @@ -56,7 +56,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue, boost::beast::flat_buffer&& buffer) @@ -102,7 +102,7 @@ class WsUpgrader : public std::enable_shared_from_this std::shared_ptr balancer_; std::shared_ptr etl_; util::TagDecoratorFactory const& tagFactory_; - DOSGuard& dosGuard_; + clio::DOSGuard& dosGuard_; RPC::Counters& counters_; WorkQueue& queue_; http::request req_; @@ -118,7 +118,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue, boost::beast::flat_buffer&& b) @@ -145,7 +145,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue, boost::beast::flat_buffer&& b, diff --git a/src/webserver/SslHttpSession.h b/src/webserver/SslHttpSession.h index 7db75220..34d3b76b 100644 --- a/src/webserver/SslHttpSession.h +++ b/src/webserver/SslHttpSession.h @@ -44,7 +44,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue, boost::beast::flat_buffer buffer) diff --git a/src/webserver/SslWsSession.h b/src/webserver/SslWsSession.h index d7cc660b..00a867d7 100644 --- a/src/webserver/SslWsSession.h +++ b/src/webserver/SslWsSession.h @@ -54,7 +54,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue, boost::beast::flat_buffer&& b) @@ -99,7 +99,7 @@ class SslWsUpgrader : public std::enable_shared_from_this std::shared_ptr balancer_; std::shared_ptr etl_; util::TagDecoratorFactory const& tagFactory_; - DOSGuard& dosGuard_; + clio::DOSGuard& dosGuard_; RPC::Counters& counters_; WorkQueue& queue_; http::request req_; @@ -115,7 +115,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue, boost::beast::flat_buffer&& b) @@ -142,7 +142,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue, boost::beast::flat_buffer&& b, diff --git a/src/webserver/WsBase.h b/src/webserver/WsBase.h index 1ee9e4d3..a9fb0d1c 100644 --- a/src/webserver/WsBase.h +++ b/src/webserver/WsBase.h @@ -121,7 +121,7 @@ class WsSession : public WsBase, std::shared_ptr balancer_; std::shared_ptr etl_; util::TagDecoratorFactory const& tagFactory_; - DOSGuard& dosGuard_; + clio::DOSGuard& dosGuard_; RPC::Counters& counters_; WorkQueue& queue_; std::mutex mtx_; @@ -155,7 +155,7 @@ public: std::shared_ptr balancer, std::shared_ptr etl, util::TagDecoratorFactory const& tagFactory, - DOSGuard& dosGuard, + clio::DOSGuard& dosGuard, RPC::Counters& counters, WorkQueue& queue, boost::beast::flat_buffer&& buffer) diff --git a/unittests/DOSGuard.cpp b/unittests/DOSGuard.cpp new file mode 100644 index 00000000..7053135b --- /dev/null +++ b/unittests/DOSGuard.cpp @@ -0,0 +1,152 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2022, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include + +#include +#include + +#include +#include + +using namespace testing; +using namespace clio; +using namespace std; +namespace json = boost::json; + +namespace { +constexpr static auto JSONData = R"JSON( + { + "dos_guard": { + "max_fetches": 100, + "sweep_interval": 1, + "max_connections": 2, + "whitelist": ["127.0.0.1"] + } + } +)JSON"; + +constexpr static auto JSONData2 = R"JSON( + { + "dos_guard": { + "max_fetches": 100, + "sweep_interval": 0.1, + "max_connections": 2, + "whitelist": ["127.0.0.1"] + } + } +)JSON"; + +constexpr static auto IP = "127.0.0.2"; + +class FakeSweepHandler +{ +private: + using guard_type = BasicDOSGuard; + guard_type* dosGuard_; + +public: + void + setup(guard_type* guard) + { + dosGuard_ = guard; + } + + void + sweep() + { + dosGuard_->clear(); + } +}; +}; // namespace + +class DOSGuardTest : public NoLoggerFixture +{ +protected: + Config cfg{json::parse(JSONData)}; + FakeSweepHandler sweepHandler; + BasicDOSGuard guard{cfg, sweepHandler}; +}; + +TEST_F(DOSGuardTest, Whitelisting) +{ + EXPECT_TRUE(guard.isWhiteListed("127.0.0.1")); + EXPECT_FALSE(guard.isWhiteListed(IP)); +} + +TEST_F(DOSGuardTest, ConnectionCount) +{ + EXPECT_TRUE(guard.isOk(IP)); + guard.increment(IP); // one connection + EXPECT_TRUE(guard.isOk(IP)); + guard.increment(IP); // two connections + EXPECT_TRUE(guard.isOk(IP)); + guard.increment(IP); // > two connections, can't connect more + EXPECT_FALSE(guard.isOk(IP)); + + guard.decrement(IP); + EXPECT_TRUE(guard.isOk(IP)); // can connect again +} + +TEST_F(DOSGuardTest, FetchCount) +{ + EXPECT_TRUE(guard.add(IP, 50)); // half of allowence + EXPECT_TRUE(guard.add(IP, 50)); // now fully charged + EXPECT_FALSE(guard.add(IP, 1)); // can't add even 1 anymore + EXPECT_FALSE(guard.isOk(IP)); + + guard.clear(); // force clear the above fetch count + EXPECT_TRUE(guard.isOk(IP)); // can fetch again +} + +TEST_F(DOSGuardTest, ClearFetchCountOnTimer) +{ + EXPECT_TRUE(guard.add(IP, 50)); // half of allowence + EXPECT_TRUE(guard.add(IP, 50)); // now fully charged + EXPECT_FALSE(guard.add(IP, 1)); // can't add even 1 anymore + EXPECT_FALSE(guard.isOk(IP)); + + sweepHandler.sweep(); // pretend sweep called from timer + EXPECT_TRUE(guard.isOk(IP)); // can fetch again +} + +template +struct BasicDOSGuardMock : public BaseDOSGuard +{ + BasicDOSGuardMock(SweepHandler& handler) + { + handler.setup(this); + } + + MOCK_METHOD(void, clear, (), (noexcept, override)); +}; + +class DOSGuardIntervalSweepHandlerTest : public SyncAsioContextTest +{ +protected: + Config cfg{json::parse(JSONData2)}; + IntervalSweepHandler sweepHandler{cfg, ctx}; + BasicDOSGuardMock guard{sweepHandler}; +}; + +TEST_F(DOSGuardIntervalSweepHandlerTest, SweepAfterInterval) +{ + EXPECT_CALL(guard, clear()).Times(Exactly(2)); + ctx.run_for(std::chrono::milliseconds(210)); +} diff --git a/unittests/util/Fixtures.h b/unittests/util/Fixtures.h index 6ca01cfb..8146cfbc 100644 --- a/unittests/util/Fixtures.h +++ b/unittests/util/Fixtures.h @@ -21,6 +21,9 @@ #include #include +#include + +#include #include #include @@ -107,3 +110,46 @@ protected: boost::log::core::get()->set_logging_enabled(false); } }; + +/** + * @brief Fixture with an embedded boost::asio context running on a thread + * + * This is meant to be used as a base for other fixtures. + */ +struct AsyncAsioContextTest : public NoLoggerFixture +{ + AsyncAsioContextTest() + { + work.emplace(ctx); // make sure ctx does not stop on its own + } + + ~AsyncAsioContextTest() + { + work.reset(); + ctx.stop(); + } + +protected: + boost::asio::io_context ctx; + +private: + std::optional work; + std::jthread runner{[this] { ctx.run(); }}; +}; + +/** + * @brief Fixture with an embedded boost::asio context that is not running by + * default but can be progressed on the calling thread + * + * Use `run_for(duration)` etc. directly on `ctx`. + * This is meant to be used as a base for other fixtures. + */ +struct SyncAsioContextTest : public NoLoggerFixture +{ + SyncAsioContextTest() + { + } + +protected: + boost::asio::io_context ctx; +};