diff --git a/src/app/ClioApplication.cpp b/src/app/ClioApplication.cpp index 59ffc0f7..c1aea46b 100644 --- a/src/app/ClioApplication.cpp +++ b/src/app/ClioApplication.cpp @@ -45,6 +45,7 @@ #include "web/Server.hpp" #include "web/dosguard/DOSGuard.hpp" #include "web/dosguard/IntervalSweepHandler.hpp" +#include "web/dosguard/Weights.hpp" #include "web/dosguard/WhitelistHandler.hpp" #include "web/ng/RPCServerHandler.hpp" #include "web/ng/Server.hpp" @@ -104,7 +105,8 @@ ClioApplication::run(bool const useNgWebServer) // Rate limiter, to prevent abuse auto whitelistHandler = web::dosguard::WhitelistHandler{config_}; - auto dosGuard = web::dosguard::DOSGuard{config_, whitelistHandler}; + auto const dosguardWeights = web::dosguard::Weights::make(config_); + auto dosGuard = web::dosguard::DOSGuard{config_, whitelistHandler, dosguardWeights}; auto sweepHandler = web::dosguard::IntervalSweepHandler{config_, ioc, dosGuard}; auto cache = data::LedgerCache{}; @@ -158,7 +160,7 @@ ClioApplication::run(bool const useNgWebServer) RPCEngineType::makeRPCEngine(config_, backend, balancer, dosGuard, workQueue, counters, handlerProvider); if (useNgWebServer or config_.get("server.__ng_web_server")) { - web::ng::RPCServerHandler handler{config_, backend, rpcEngine, etl}; + web::ng::RPCServerHandler handler{config_, backend, rpcEngine, etl, dosGuard}; auto expectedAdminVerifier = web::makeAdminVerificationStrategy(config_); if (not expectedAdminVerifier.has_value()) { @@ -176,7 +178,7 @@ ClioApplication::run(bool const useNgWebServer) httpServer->onGet("/metrics", MetricsHandler{adminVerifier}); httpServer->onGet("/health", HealthCheckHandler{}); - auto requestHandler = RequestHandler{adminVerifier, handler, dosGuard}; + auto requestHandler = RequestHandler{adminVerifier, handler}; httpServer->onPost("/", requestHandler); httpServer->onWs(std::move(requestHandler)); @@ -199,7 +201,7 @@ ClioApplication::run(bool const useNgWebServer) } // Init the web server - auto handler = std::make_shared>(config_, backend, rpcEngine, etl); + auto handler = std::make_shared>(config_, backend, rpcEngine, etl, dosGuard); auto const httpServer = web::makeHttpServer(config_, ioc, dosGuard, handler); diff --git a/src/app/WebHandlers.hpp b/src/app/WebHandlers.hpp index ac7958ac..5dacc222 100644 --- a/src/app/WebHandlers.hpp +++ b/src/app/WebHandlers.hpp @@ -147,7 +147,6 @@ class RequestHandler { util::Logger webServerLog_{"WebServer"}; std::shared_ptr adminVerifier_; std::reference_wrapper rpcHandler_; - std::reference_wrapper dosguard_; public: /** @@ -155,14 +154,9 @@ public: * * @param adminVerifier The AdminVerificationStrategy to use for verifying the connection for admin access. * @param rpcHandler The RPC handler to use for handling the request. - * @param dosguard The DOSGuardInterface to use for checking the connection. */ - RequestHandler( - std::shared_ptr adminVerifier, - RpcHandlerType& rpcHandler, - web::dosguard::DOSGuardInterface& dosguard - ) - : adminVerifier_(std::move(adminVerifier)), rpcHandler_(rpcHandler), dosguard_(dosguard) + RequestHandler(std::shared_ptr adminVerifier, RpcHandlerType& rpcHandler) + : adminVerifier_(std::move(adminVerifier)), rpcHandler_(rpcHandler) { } @@ -183,21 +177,6 @@ public: boost::asio::yield_context yield ) { - if (not dosguard_.get().request(connectionMetadata.ip())) { - auto error = rpc::makeError(rpc::RippledError::rpcSLOW_DOWN); - - if (not request.isHttp()) { - try { - auto requestJson = boost::json::parse(request.message()); - if (requestJson.is_object() && requestJson.as_object().contains("id")) - error["id"] = requestJson.as_object().at("id"); - error["request"] = request.message(); - } catch (std::exception const&) { - error["request"] = request.message(); - } - } - return web::ng::Response{boost::beast::http::status::service_unavailable, error, request}; - } LOG(webServerLog_.info()) << connectionMetadata.tag() << "Received request from ip = " << connectionMetadata.ip() << " - posting to WorkQueue"; @@ -207,20 +186,7 @@ public: }); try { - auto response = rpcHandler_(request, connectionMetadata, std::move(subscriptionContext), yield); - - if (not dosguard_.get().add(connectionMetadata.ip(), response.message().size())) { - auto jsonResponse = boost::json::parse(response.message()).as_object(); - jsonResponse["warning"] = "load"; - if (jsonResponse.contains("warnings") && jsonResponse["warnings"].is_array()) { - jsonResponse["warnings"].as_array().push_back(rpc::makeWarning(rpc::WarnRpcRateLimit)); - } else { - jsonResponse["warnings"] = boost::json::array{rpc::makeWarning(rpc::WarnRpcRateLimit)}; - } - response.setMessage(jsonResponse); - } - - return response; + return rpcHandler_(request, connectionMetadata, std::move(subscriptionContext), yield); } catch (std::exception const&) { return web::ng::Response{ boost::beast::http::status::internal_server_error, diff --git a/src/rpc/CMakeLists.txt b/src/rpc/CMakeLists.txt index 557e843d..bce9f847 100644 --- a/src/rpc/CMakeLists.txt +++ b/src/rpc/CMakeLists.txt @@ -1,3 +1,8 @@ +# Have to use RPCCenter as a separate library since it is used in util +add_library(clio_rpc_center) +target_sources(clio_rpc_center PRIVATE RPCCenter.cpp) +target_include_directories(clio_rpc_center PUBLIC "${CMAKE_SOURCE_DIR}/src") + add_library(clio_rpc) target_sources( diff --git a/src/rpc/RPCCenter.cpp b/src/rpc/RPCCenter.cpp new file mode 100644 index 00000000..74e079b1 --- /dev/null +++ b/src/rpc/RPCCenter.cpp @@ -0,0 +1,112 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "rpc/RPCCenter.hpp" + +#include +#include + +namespace rpc { + +namespace { + +std::unordered_set const& +handledRpcs() +{ + static std::unordered_set kHANDLED_RPCS = { + "account_channels", + "account_currencies", + "account_info", + "account_lines", + "account_nfts", + "account_objects", + "account_offers", + "account_tx", + "amm_info", + "book_changes", + "book_offers", + "deposit_authorized", + "feature", + "gateway_balances", + "get_aggregate_price", + "ledger", + "ledger_data", + "ledger_entry", + "ledger_index", + "ledger_range", + "mpt_holders", + "nfts_by_issuer", + "nft_history", + "nft_buy_offers", + "nft_info", + "nft_sell_offers", + "noripple_check", + "ping", + "random", + "server_info", + "transaction_entry", + "tx", + "subscribe", + "unsubscribe", + "version", + }; + return kHANDLED_RPCS; +} + +std::unordered_set const& +forwardedRpcs() +{ + static std::unordered_set const kFORWARDED_RPCS = { + "server_definitions", + "server_state", + "submit", + "submit_multisigned", + "fee", + "ledger_closed", + "ledger_current", + "ripple_path_find", + "manifest", + "channel_authorize", + "channel_verify", + "simulate", + }; + return kFORWARDED_RPCS; +} + +} // namespace + +bool +RPCCenter::isRpcName(std::string_view s) +{ + return isHandled(s) || isForwarded(s); +} + +bool +RPCCenter::isHandled(std::string_view s) +{ + return handledRpcs().contains(s); +} + +bool +RPCCenter::isForwarded(std::string_view s) +{ + return forwardedRpcs().contains(s); +} + +} // namespace rpc diff --git a/src/rpc/RPCCenter.hpp b/src/rpc/RPCCenter.hpp new file mode 100644 index 00000000..c772c5a2 --- /dev/null +++ b/src/rpc/RPCCenter.hpp @@ -0,0 +1,61 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include + +namespace rpc { + +/** + * @brief Registry of RPC commands supported by Clio + * + * The RPCCenter maintains lists of RPC commands that can be handled locally + * and those that need to be forwarded to rippled. + */ +struct RPCCenter { + /** + * @brief Checks if a string is a valid RPC command name + * + * @param s The string to check + * @return true if the string is a recognized RPC name, false otherwise + */ + static bool + isRpcName(std::string_view s); + + /** + * @brief Checks if a string is a RPC command handled by Clio without forwarding to rippled + * + * @param s The string to check + * @return true if the string is a handled RPC command, false otherwise + */ + static bool + isHandled(std::string_view s); + + /** + * @brief Checks if a string is a RPC command that will be forwarded to rippled + * + * @param s The string to check + * @return true if the string is a forwarded RPC command, false otherwise + */ + static bool + isForwarded(std::string_view s); +}; + +} // namespace rpc diff --git a/src/rpc/common/impl/ForwardingProxy.hpp b/src/rpc/common/impl/ForwardingProxy.hpp index 2c07b79a..8bf96ceb 100644 --- a/src/rpc/common/impl/ForwardingProxy.hpp +++ b/src/rpc/common/impl/ForwardingProxy.hpp @@ -21,6 +21,7 @@ #include "etlng/LoadBalancerInterface.hpp" #include "rpc/Errors.hpp" +#include "rpc/RPCCenter.hpp" #include "rpc/RPCHelpers.hpp" #include "rpc/common/Types.hpp" #include "util/log/Logger.hpp" @@ -106,22 +107,7 @@ public: bool isProxied(std::string const& method) const { - static std::unordered_set const kPROXIED_COMMANDS{ - "server_definitions", - "server_state", - "submit", - "submit_multisigned", - "fee", - "ledger_closed", - "ledger_current", - "ripple_path_find", - "manifest", - "channel_authorize", - "channel_verify", - "simulate", - }; - - return kPROXIED_COMMANDS.contains(method); + return RPCCenter::isForwarded(method); } private: diff --git a/src/rpc/common/impl/HandlerProvider.cpp b/src/rpc/common/impl/HandlerProvider.cpp index c05e5cd5..2f360a9a 100644 --- a/src/rpc/common/impl/HandlerProvider.cpp +++ b/src/rpc/common/impl/HandlerProvider.cpp @@ -66,6 +66,7 @@ #include #include #include +#include namespace rpc::impl { @@ -139,4 +140,13 @@ ProductionHandlerProvider::isClioOnly(std::string const& command) const return handlerMap_.contains(command) && handlerMap_.at(command).isClioOnly; } +std::unordered_set +ProductionHandlerProvider::handlerNames() const +{ + std::unordered_set result; + for (auto const& [name, handler] : handlerMap_) + result.insert(name); + return result; +} + } // namespace rpc::impl diff --git a/src/rpc/common/impl/HandlerProvider.hpp b/src/rpc/common/impl/HandlerProvider.hpp index 2c07428c..0b54be3c 100644 --- a/src/rpc/common/impl/HandlerProvider.hpp +++ b/src/rpc/common/impl/HandlerProvider.hpp @@ -33,6 +33,7 @@ #include #include #include +#include namespace rpc { class Counters; @@ -67,6 +68,9 @@ public: bool isClioOnly(std::string const& command) const override; + + std::unordered_set + handlerNames() const; }; } // namespace rpc::impl diff --git a/src/util/CMakeLists.txt b/src/util/CMakeLists.txt index 51d44765..838b5b11 100644 --- a/src/util/CMakeLists.txt +++ b/src/util/CMakeLists.txt @@ -24,6 +24,7 @@ target_sources( ResponseExpirationCache.cpp SignalsHandler.cpp StopHelper.cpp + StringHash.cpp Taggable.cpp TerminationHandler.cpp TimeUtils.cpp @@ -41,7 +42,7 @@ target_sources( # This must be above the target_link_libraries call otherwise backtrace doesn't work if ("${san}" STREQUAL "") - target_link_libraries(clio_util PUBLIC Boost::stacktrace_backtrace dl libbacktrace::libbacktrace) + target_link_libraries(clio_util PUBLIC Boost::stacktrace_backtrace dl libbacktrace::libbacktrace clio_rpc_center) endif () target_link_libraries( diff --git a/src/util/StringHash.cpp b/src/util/StringHash.cpp new file mode 100644 index 00000000..3f26feb2 --- /dev/null +++ b/src/util/StringHash.cpp @@ -0,0 +1,46 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "util/StringHash.hpp" + +#include +#include +#include + +namespace util { + +size_t +StringHash::operator()(char const* str) const +{ + return hash_type{}(str); +} + +size_t +StringHash::operator()(std::string_view str) const +{ + return hash_type{}(str); +} + +size_t +StringHash::operator()(std::string const& str) const +{ + return hash_type{}(str); +} + +} // namespace util diff --git a/src/util/StringHash.hpp b/src/util/StringHash.hpp new file mode 100644 index 00000000..862c8ed8 --- /dev/null +++ b/src/util/StringHash.hpp @@ -0,0 +1,65 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include +#include +#include +#include + +namespace util { + +/** + * @brief A string hash functor that provides transparent hash operations for various string types. + * + * This hash functor can be used with unordered containers to enable heterogeneous lookups + * for different string-like types without unnecessary conversions. It supports C-style strings, + * string views, and standard strings. + */ +struct StringHash { + using hash_type = std::hash; + using is_transparent = void; ///< Enables heterogeneous lookup + + /** + * @brief Computes the hash of a C-style string. + * @param str Null-terminated C-style string to hash + * @return Size_t hash value + */ + std::size_t + operator()(char const* str) const; + + /** + * @brief Computes the hash of a string_view. + * @param str String view to hash + * @return Size_t hash value + */ + std::size_t + operator()(std::string_view str) const; + + /** + * @brief Computes the hash of a standard string. + * @param str String to hash + * @return Size_t hash value + */ + std::size_t + operator()(std::string const& str) const; +}; + +} // namespace util diff --git a/src/util/newconfig/ConfigConstraints.cpp b/src/util/newconfig/ConfigConstraints.cpp index 2a0f8fb6..41d5210b 100644 --- a/src/util/newconfig/ConfigConstraints.cpp +++ b/src/util/newconfig/ConfigConstraints.cpp @@ -19,6 +19,7 @@ #include "util/newconfig/ConfigConstraints.hpp" +#include "rpc/RPCCenter.hpp" #include "util/newconfig/Error.hpp" #include "util/newconfig/Types.hpp" @@ -103,4 +104,22 @@ PositiveDouble::checkValueImpl(Value const& num) const return Error{"Double number must be greater than or equal to 0"}; } +std::optional +RpcNameConstraint::checkTypeImpl(Value const& value) const +{ + if (not std::holds_alternative(value)) + return Error{"RPC command name must be a string"}; + return std::nullopt; +} + +std::optional +RpcNameConstraint::checkValueImpl(Value const& value) const +{ + auto const str = std::get(value); + if (not rpc::RPCCenter::isRpcName(str)) + return Error{"Invalid RPC command name"}; + + return std::nullopt; +} + } // namespace util::config diff --git a/src/util/newconfig/ConfigConstraints.hpp b/src/util/newconfig/ConfigConstraints.hpp index 752b8b6d..27327bea 100644 --- a/src/util/newconfig/ConfigConstraints.hpp +++ b/src/util/newconfig/ConfigConstraints.hpp @@ -427,6 +427,41 @@ private: } }; +/** + * @brief A constraint to ensure the value is a valid RPC command name. + */ +class RpcNameConstraint final : public Constraint { +private: + /** + * @brief Check if the type of the value is correct for this specific constraint. + * + * @param value The type to be checked + * @return An Error object if the constraint is not met, nullopt otherwise + */ + [[nodiscard]] std::optional + checkTypeImpl(Value const& value) const override; + + /** + * @brief Check if the value is a valid RPC command name. + * + * @param value The value to check + * @return An Error object if the constraint is not met, nullopt otherwise + */ + [[nodiscard]] std::optional + checkValueImpl(Value const& value) const override; + + /** + * @brief Prints to the output stream for this specific constraint. + * + * @param stream The output stream + */ + void + print(std::ostream& stream) const override + { + stream << "Checks whether provided RPC name is valid"; + } +}; + static constinit PortConstraint gValidatePort{}; static constinit ValidIPConstraint gValidateIp{}; @@ -448,6 +483,8 @@ static constinit NumberValueConstraint gValidateReplicationFactor{0, s static constinit NumberValueConstraint gValidateUint16{1, std::numeric_limits::max()}; static constinit NumberValueConstraint gValidateUint32{1, std::numeric_limits::max()}; +static constinit NumberValueConstraint gValidateNonNegativeUint32{0, std::numeric_limits::max()}; static constinit NumberValueConstraint gValidateApiVersion{rpc::kAPI_VERSION_MIN, rpc::kAPI_VERSION_MAX}; +static constinit RpcNameConstraint gRpcNameConstraint{}; } // namespace util::config diff --git a/src/util/newconfig/ConfigDefinition.hpp b/src/util/newconfig/ConfigDefinition.hpp index 8528b1f2..9d3b7e29 100644 --- a/src/util/newconfig/ConfigDefinition.hpp +++ b/src/util/newconfig/ConfigDefinition.hpp @@ -314,6 +314,15 @@ static ClioConfigDefinition gClioConfig = ClioConfigDefinition{ {"dos_guard.max_requests", ConfigValue{ConfigType::Integer}.defaultValue(20u).withConstraint(gValidateUint32)}, {"dos_guard.sweep_interval", ConfigValue{ConfigType::Double}.defaultValue(1.0).withConstraint(gValidatePositiveDouble)}, + {"dos_guard.__ng_default_weight", + ConfigValue{ConfigType::Integer}.defaultValue(1).withConstraint(gValidateNonNegativeUint32)}, + {"dos_guard.__ng_weights.[].method", Array{ConfigValue{ConfigType::String}.withConstraint(gRpcNameConstraint)}}, + {"dos_guard.__ng_weights.[].weight", + Array{ConfigValue{ConfigType::Integer}.withConstraint(gValidateNonNegativeUint32)}}, + {"dos_guard.__ng_weights.[].weight_ledger_current", + Array{ConfigValue{ConfigType::Integer}.optional().withConstraint(gValidateNonNegativeUint32)}}, + {"dos_guard.__ng_weights.[].weight_ledger_validated", + Array{ConfigValue{ConfigType::Integer}.optional().withConstraint(gValidateNonNegativeUint32)}}, {"workers", ConfigValue{ConfigType::Integer} diff --git a/src/web/CMakeLists.txt b/src/web/CMakeLists.txt index 251fda1b..2b62bbe4 100644 --- a/src/web/CMakeLists.txt +++ b/src/web/CMakeLists.txt @@ -5,6 +5,7 @@ target_sources( PRIVATE AdminVerificationStrategy.cpp dosguard/DOSGuard.cpp dosguard/IntervalSweepHandler.cpp + dosguard/Weights.cpp dosguard/WhitelistHandler.cpp ng/Connection.cpp ng/impl/ErrorHandling.cpp diff --git a/src/web/RPCServerHandler.hpp b/src/web/RPCServerHandler.hpp index 2d149b3b..34ab0394 100644 --- a/src/web/RPCServerHandler.hpp +++ b/src/web/RPCServerHandler.hpp @@ -31,6 +31,7 @@ #include "util/Taggable.hpp" #include "util/log/Logger.hpp" #include "util/newconfig/ConfigDefinition.hpp" +#include "web/dosguard/DOSGuardInterface.hpp" #include "web/impl/ErrorHandling.hpp" #include "web/interface/ConnectionBase.hpp" @@ -66,6 +67,7 @@ class RPCServerHandler { std::shared_ptr const etl_; util::TagDecoratorFactory const tagFactory_; rpc::impl::ProductionAPIVersionParser apiVersionParser_; // can be injected if needed + std::reference_wrapper dosguard_; util::Logger log_{"RPC"}; util::Logger perfLog_{"Performance"}; @@ -78,18 +80,21 @@ public: * @param backend The backend to use * @param rpcEngine The RPC engine to use * @param etl The ETL to use + * @param dosguard The DOS guard service to use for request rate limiting */ RPCServerHandler( util::config::ClioConfigDefinition const& config, std::shared_ptr const& backend, std::shared_ptr const& rpcEngine, - std::shared_ptr const& etl + std::shared_ptr const& etl, + web::dosguard::DOSGuardInterface& dosguard ) : backend_(backend) , rpcEngine_(rpcEngine) , etl_(etl) , tagFactory_(config) , apiVersionParser_(config.getObject("api_version")) + , dosguard_(dosguard) { } @@ -102,6 +107,11 @@ public: void operator()(std::string const& request, std::shared_ptr const& connection) { + if (not dosguard_.get().isOk(connection->clientIp)) { + connection->sendSlowDown(request); + return; + } + try { auto req = boost::json::parse(request).as_object(); LOG(perfLog_.debug()) << connection->tag() << "Adding to work queue"; @@ -109,6 +119,11 @@ public: if (not connection->upgraded and shouldReplaceParams(req)) req[JS(params)] = boost::json::array({boost::json::object{}}); + if (not dosguard_.get().request(connection->clientIp, req)) { + connection->sendSlowDown(request); + return; + } + if (!rpcEngine_->post( [this, request = std::move(req), connection](boost::asio::yield_context yield) mutable { handleRequest(yield, std::move(request), connection); diff --git a/src/web/dosguard/DOSGuard.cpp b/src/web/dosguard/DOSGuard.cpp index 3a997fa1..8698a282 100644 --- a/src/web/dosguard/DOSGuard.cpp +++ b/src/web/dosguard/DOSGuard.cpp @@ -24,8 +24,11 @@ #include "util/newconfig/ArrayView.hpp" #include "util/newconfig/ConfigDefinition.hpp" #include "util/newconfig/ValueView.hpp" +#include "web/dosguard/WeightsInterface.hpp" #include "web/dosguard/WhitelistHandlerInterface.hpp" +#include + #include #include #include @@ -37,8 +40,13 @@ using namespace util::config; namespace web::dosguard { -DOSGuard::DOSGuard(ClioConfigDefinition const& config, WhitelistHandlerInterface const& whitelistHandler) +DOSGuard::DOSGuard( + ClioConfigDefinition const& config, + WhitelistHandlerInterface const& whitelistHandler, + WeightsInterface const& weights +) : whitelistHandler_{std::cref(whitelistHandler)} + , weights_(weights) , maxFetches_{config.get("dos_guard.max_fetches")} , maxConnCount_{config.get("dos_guard.max_connections")} , maxRequestCount_{config.get("dos_guard.max_requests")} @@ -59,8 +67,8 @@ DOSGuard::isOk(std::string const& ip) const noexcept { auto lock = mtx_.lock(); - if (lock->ipState.find(ip) != lock->ipState.end()) { - auto [transferredByte, requests] = lock->ipState.at(ip); + if (auto const it = lock->ipState.find(ip); it != lock->ipState.end()) { + auto const [transferredByte, requests] = it->second; if (transferredByte > maxFetches_ || requests > maxRequestCount_) { LOG(log_.warn()) << "Dosguard: Client surpassed the rate limit. ip = " << ip << " Transferred Byte: " << transferredByte << "; Requests: " << requests; @@ -115,14 +123,16 @@ DOSGuard::add(std::string const& ip, uint32_t numObjects) noexcept } [[maybe_unused]] bool -DOSGuard::request(std::string const& ip) noexcept +DOSGuard::request(std::string const& ip, boost::json::object const& request) { if (whitelistHandler_.get().isWhiteListed(ip)) return true; + auto const weight = weights_.get().requestWeight(request); + { auto lock = mtx_.lock(); - lock->ipState[ip].requestsCount++; + lock->ipState[ip].requestsCount += weight; } return isOk(ip); diff --git a/src/web/dosguard/DOSGuard.hpp b/src/web/dosguard/DOSGuard.hpp index e4ace5d3..78ecb2b9 100644 --- a/src/web/dosguard/DOSGuard.hpp +++ b/src/web/dosguard/DOSGuard.hpp @@ -23,10 +23,12 @@ #include "util/log/Logger.hpp" #include "util/newconfig/ConfigDefinition.hpp" #include "web/dosguard/DOSGuardInterface.hpp" +#include "web/dosguard/WeightsInterface.hpp" #include "web/dosguard/WhitelistHandlerInterface.hpp" #include #include +#include #include #include @@ -59,6 +61,7 @@ class DOSGuard : public DOSGuardInterface { util::Mutex mtx_; std::reference_wrapper whitelistHandler_; + std::reference_wrapper weights_; std::uint32_t const maxFetches_; std::uint32_t const maxConnCount_; @@ -71,8 +74,13 @@ public: * * @param config Clio config * @param whitelistHandler Whitelist handler that checks whitelist for IP addresses + * @param weights API methods weights */ - DOSGuard(util::config::ClioConfigDefinition const& config, WhitelistHandlerInterface const& whitelistHandler); + DOSGuard( + util::config::ClioConfigDefinition const& config, + WhitelistHandlerInterface const& whitelistHandler, + WeightsInterface const& weights + ); /** * @brief Check whether an ip address is in the whitelist or not. @@ -133,11 +141,12 @@ public: * returned otherwise. * * @param ip + * @param request The request as json object * @return true * @return false */ [[maybe_unused]] bool - request(std::string const& ip) noexcept override; + request(std::string const& ip, boost::json::object const& request) override; /** * @brief Instantly clears all fetch counters added by @see add(std::string const&, uint32_t). diff --git a/src/web/dosguard/DOSGuardInterface.hpp b/src/web/dosguard/DOSGuardInterface.hpp index 1eee27b6..684087c9 100644 --- a/src/web/dosguard/DOSGuardInterface.hpp +++ b/src/web/dosguard/DOSGuardInterface.hpp @@ -19,6 +19,8 @@ #pragma once +#include + #include #include #include @@ -99,12 +101,13 @@ public: * * * @param ip + * @param request The request as json object * @return If the total sums up to a value equal or larger than maxRequestCount_ * the operation is no longer allowed and false is returned; true is * returned otherwise. */ [[maybe_unused]] virtual bool - request(std::string const& ip) noexcept = 0; + request(std::string const& ip, boost::json::object const& request) = 0; }; } // namespace web::dosguard diff --git a/src/web/dosguard/Weights.cpp b/src/web/dosguard/Weights.cpp new file mode 100644 index 00000000..bed67ffe --- /dev/null +++ b/src/web/dosguard/Weights.cpp @@ -0,0 +1,106 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "web/dosguard/Weights.hpp" + +#include "rpc/JS.hpp" +#include "util/Assert.hpp" +#include "util/newconfig/ArrayView.hpp" +#include "util/newconfig/ConfigDefinition.hpp" + +#include +#include + +#include +#include +#include +#include +#include +#include + +namespace web::dosguard { + +Weights::Weights(size_t defaultWeight, std::unordered_map weights) + : defaultWeight_(defaultWeight), weights_(std::move_iterator(weights.begin()), std::move_iterator(weights.end())) +{ +} + +Weights +Weights::make(util::config::ClioConfigDefinition const& config) +{ + std::unordered_map weights; + auto const configWeights = config.getArray("dos_guard.__ng_weights"); + for (size_t i = 0; i < configWeights.size(); ++i) { + auto const w = configWeights.objectAt(i); + Weights::Entry const entry{ + .weight = w.get("weight"), + .weightLedgerCurrent = w.maybeValue("weight_ledger_current"), + .weightLedgerValidated = w.maybeValue("weight_ledger_validated"), + }; + weights.emplace(w.get("method"), entry); + } + return Weights{config.get("dos_guard.__ng_default_weight"), std::move(weights)}; +} + +size_t +Weights::requestWeight(boost::json::object const& request) const +{ + if (not((request.contains(JS(method)) and request.at(JS(method)).is_string()) or + (request.contains(JS(command)) and request.at(JS(command)).is_string()))) { + return defaultWeight_; + } + + std::string_view cmd = + request.contains(JS(method)) ? request.at(JS(method)).as_string() : request.at(JS(command)).as_string(); + + auto it = weights_.find(cmd); + if (it == weights_.end()) { + return defaultWeight_; + } + + auto const& entry = it->second; + + boost::json::value const* ledgerIndex = nullptr; + if (request.contains(JS(ledger_index))) { + ledgerIndex = &request.at(JS(ledger_index)); + } else if (request.contains(JS(params))) { + ASSERT( + request.at(JS(params)).is_array() and not request.at(JS(params)).as_array().empty() and + request.at(JS(params)).as_array().at(0).is_object(), + "params should be [{{}}]" + ); + if (auto const& params = request.at(JS(params)).as_array().at(0).as_object(); + params.contains(JS(ledger_index))) { + ledgerIndex = ¶ms.at(JS(ledger_index)); + } + } + + if (ledgerIndex != nullptr and ledgerIndex->is_string()) { + auto const& ledgerIndexString = ledgerIndex->as_string(); + if (ledgerIndexString == JS(validated)) { + return entry.weightLedgerValidated.value_or(entry.weight); + } + if (ledgerIndexString == JS(current)) { + return entry.weightLedgerCurrent.value_or(entry.weight); + } + } + return entry.weight; +} + +} // namespace web::dosguard diff --git a/src/web/dosguard/Weights.hpp b/src/web/dosguard/Weights.hpp new file mode 100644 index 00000000..da041ec8 --- /dev/null +++ b/src/web/dosguard/Weights.hpp @@ -0,0 +1,88 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include "util/StringHash.hpp" +#include "util/newconfig/ConfigDefinition.hpp" +#include "web/dosguard/WeightsInterface.hpp" + +#include + +#include +#include +#include +#include +#include + +namespace web::dosguard { + +/** + * @brief Implementation of WeightsInterface that manages command weights for DosGuard. + * + * This class provides a mechanism to assign different weights to API commands + * for the purpose of DOS protection calculations. Commands can have specific weights, + * or fall back to a default weight. + */ +class Weights : public WeightsInterface { +public: + /** + * @brief Structure representing weight configuration for a command. + * + * Contains the base weight and optional specialized weights for different ledger specifications. + */ + struct Entry { + size_t weight; + std::optional weightLedgerCurrent; + std::optional weightLedgerValidated; + }; + +private: + size_t defaultWeight_; + std::unordered_map> weights_; + +public: + /** + * @brief Construct a new Weights object + * + * @param defaultWeight The default weight to use when a command-specific weight is not defined + * @param weights Map of command names to their specific weights + */ + Weights(size_t defaultWeight, std::unordered_map weights); + + /** + * @brief Create a Weights object from configuration + * + * @param config The application configuration + * @return Weights instance initialized with values from configuration + */ + static Weights + make(util::config::ClioConfigDefinition const& config); + + /** + * @brief Get the weight assigned to a specific command + * + * @param request Json request + * @return size_t The weight value (specific weight if defined, otherwise default weight) + */ + size_t + requestWeight(boost::json::object const& request) const override; +}; + +} // namespace web::dosguard diff --git a/src/web/dosguard/WeightsInterface.hpp b/src/web/dosguard/WeightsInterface.hpp new file mode 100644 index 00000000..916a25f4 --- /dev/null +++ b/src/web/dosguard/WeightsInterface.hpp @@ -0,0 +1,48 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#pragma once + +#include + +#include + +namespace web::dosguard { + +/** + * @brief Interface for determining request weights in DOS protection. + * + * This interface defines the contract for classes that calculate weights for incoming + * requests, which is used for DOS protection mechanisms. + */ +class WeightsInterface { +public: + virtual ~WeightsInterface() = default; + + /** + * @brief Calculate the weight of a request. + * + * @param request The JSON object representing the request + * @return The calculated weight of the request + */ + virtual size_t + requestWeight(boost::json::object const& request) const = 0; +}; + +} // namespace web::dosguard diff --git a/src/web/impl/HttpBase.hpp b/src/web/impl/HttpBase.hpp index cfc257e8..5f3fa26c 100644 --- a/src/web/impl/HttpBase.hpp +++ b/src/web/impl/HttpBase.hpp @@ -240,19 +240,7 @@ public: return sender_(httpResponse(http::status::bad_request, "text/html", "Expected a POST request")); } - // to avoid overwhelm work queue, the request limit check should be - // before posting to queue the web socket creation will be guarded via - // connection limit - if (!dosGuard_.get().request(clientIp)) { - // TODO: this looks like it could be useful to count too in the future - return sender_(httpResponse( - http::status::service_unavailable, - "text/plain", - boost::json::serialize(rpc::makeError(rpc::RippledError::rpcSLOW_DOWN)) - )); - } - - LOG(log_.info()) << tag() << "Received request from ip = " << clientIp << " - posting to WorkQueue"; + LOG(log_.info()) << tag() << "Received request from ip = " << clientIp; try { (*handler_)(req_.body(), derived().shared_from_this()); @@ -265,6 +253,16 @@ public: } } + void + sendSlowDown(std::string const&) override + { + sender_(httpResponse( + http::status::service_unavailable, + "text/plain", + boost::json::serialize(rpc::makeError(rpc::RippledError::rpcSLOW_DOWN)) + )); + } + /** * @brief Send a response to the client * The message length will be added to the DOSGuard, if the limit is reached, a warning will be added to the diff --git a/src/web/impl/WsBase.hpp b/src/web/impl/WsBase.hpp index bdd6d860..195ead02 100644 --- a/src/web/impl/WsBase.hpp +++ b/src/web/impl/WsBase.hpp @@ -164,6 +164,12 @@ public: doWrite(); } + void + sendSlowDown(std::string const& request) override + { + sendError(rpc::RippledError::rpcSLOW_DOWN, request); + } + /** * @brief Send a message to the client * @param msg The message to send, it will keep the string alive until it is sent. It is useful when we have @@ -280,36 +286,33 @@ public: LOG(perfLog_.info()) << tag() << "Received request from ip = " << this->clientIp; - auto sendError = [this](auto error, std::string&& requestStr) { - auto e = rpc::makeError(error); - - try { - auto request = boost::json::parse(requestStr); - if (request.is_object() && request.as_object().contains("id")) - e["id"] = request.as_object().at("id"); - e["request"] = std::move(request); - } catch (std::exception const&) { - e["request"] = std::move(requestStr); - } - - this->send(std::make_shared(boost::json::serialize(e))); - }; - std::string requestStr{static_cast(buffer_.data().data()), buffer_.size()}; - // dosGuard served request++ and check ip address - if (!dosGuard_.get().request(clientIp)) { - // TODO: could be useful to count in counters in the future too - sendError(rpc::RippledError::rpcSLOW_DOWN, std::move(requestStr)); - } else { - try { - (*handler_)(requestStr, shared_from_this()); - } catch (std::exception const&) { - sendError(rpc::RippledError::rpcINTERNAL, std::move(requestStr)); - } + try { + (*handler_)(requestStr, shared_from_this()); + } catch (std::exception const&) { + sendError(rpc::RippledError::rpcINTERNAL, std::move(requestStr)); } doRead(); } + +private: + void + sendError(rpc::RippledError error, std::string requestStr) + { + auto e = rpc::makeError(error); + + try { + auto request = boost::json::parse(requestStr); + if (request.is_object() && request.as_object().contains("id")) + e["id"] = request.as_object().at("id"); + e["request"] = std::move(request); + } catch (std::exception const&) { + e["request"] = std::move(requestStr); + } + + this->send(std::make_shared(boost::json::serialize(e))); + } }; } // namespace web::impl diff --git a/src/web/interface/ConnectionBase.hpp b/src/web/interface/ConnectionBase.hpp index 839b1d9b..894ce0de 100644 --- a/src/web/interface/ConnectionBase.hpp +++ b/src/web/interface/ConnectionBase.hpp @@ -82,6 +82,13 @@ public: throw std::logic_error("web server can not send the shared payload"); } + /** + * @brief Send a "slow down" error response to the client. + * + * @param request The original request that triggered the rate limiting + */ + virtual void + sendSlowDown(std::string const& request) = 0; /** * @brief Get the subscription context for this connection. * diff --git a/src/web/ng/RPCServerHandler.hpp b/src/web/ng/RPCServerHandler.hpp index 05a1f806..64db0062 100644 --- a/src/web/ng/RPCServerHandler.hpp +++ b/src/web/ng/RPCServerHandler.hpp @@ -33,6 +33,7 @@ #include "util/Taggable.hpp" #include "util/log/Logger.hpp" #include "web/SubscriptionContextInterface.hpp" +#include "web/dosguard/DOSGuardInterface.hpp" #include "web/ng/Connection.hpp" #include "web/ng/Request.hpp" #include "web/ng/Response.hpp" @@ -70,6 +71,7 @@ class RPCServerHandler { std::shared_ptr const backend_; std::shared_ptr const rpcEngine_; std::shared_ptr const etl_; + std::reference_wrapper dosguard_; util::TagDecoratorFactory const tagFactory_; rpc::impl::ProductionAPIVersionParser apiVersionParser_; // can be injected if needed @@ -84,16 +86,19 @@ public: * @param backend The backend to use * @param rpcEngine The RPC engine to use * @param etl The ETL to use + * @param dosguard The DOS guard service to use for request rate limiting */ RPCServerHandler( util::config::ClioConfigDefinition const& config, std::shared_ptr const& backend, std::shared_ptr const& rpcEngine, - std::shared_ptr const& etl + std::shared_ptr const& etl, + dosguard::DOSGuardInterface& dosguard ) : backend_(backend) , rpcEngine_(rpcEngine) , etl_(etl) + , dosguard_(dosguard) , tagFactory_(config) , apiVersionParser_(config.getObject("api_version")) { @@ -116,6 +121,10 @@ public: boost::asio::yield_context yield ) { + if (not dosguard_.get().isOk(connectionMetadata.ip())) { + return makeSlowDownResponse(request, std::nullopt); + } + std::optional response; util::CoroutineGroup coroutineGroup{yield, 1}; auto const onTaskComplete = coroutineGroup.registerForeign(yield); @@ -142,18 +151,23 @@ public: } } else { auto parsedObject = std::move(parsedRequest).as_object(); - LOG(perfLog_.debug()) << connectionMetadata.tag() << "Adding to work queue"; - if (not connectionMetadata.wasUpgraded() and shouldReplaceParams(parsedObject)) - parsedObject[JS(params)] = boost::json::array({boost::json::object{}}); + if (not dosguard_.get().request(connectionMetadata.ip(), parsedObject)) { + response = makeSlowDownResponse(request, parsedObject); + } else { + LOG(perfLog_.debug()) << connectionMetadata.tag() << "Adding to work queue"; - response = handleRequest( - innerYield, - request, - std::move(parsedObject), - connectionMetadata, - std::move(subscriptionContext) - ); + if (not connectionMetadata.wasUpgraded() and shouldReplaceParams(parsedObject)) + parsedObject[JS(params)] = boost::json::array({boost::json::object{}}); + + response = handleRequest( + innerYield, + request, + std::move(parsedObject), + connectionMetadata, + std::move(subscriptionContext) + ); + } } } catch (std::exception const& ex) { LOG(perfLog_.error()) << connectionMetadata.tag() << "Caught exception: " << ex.what(); @@ -177,6 +191,11 @@ public: // Put the coroutine to sleep until the foreign task is done coroutineGroup.asyncWait(yield); ASSERT(response.has_value(), "Woke up coroutine without setting response"); + + if (not dosguard_.get().add(connectionMetadata.ip(), response->message().size())) { + response->setMessage(makeLoadWarning(*response)); + } + return std::move(response).value(); } @@ -316,6 +335,39 @@ private: } } + static Response + makeSlowDownResponse(Request const& request, std::optional requestJson) + { + auto error = rpc::makeError(rpc::RippledError::rpcSLOW_DOWN); + + if (not request.isHttp()) { + try { + if (not requestJson.has_value()) { + requestJson = boost::json::parse(request.message()); + } + if (requestJson->is_object() && requestJson->as_object().contains("id")) + error["id"] = requestJson->as_object().at("id"); + error["request"] = request.message(); + } catch (std::exception const&) { + error["request"] = request.message(); + } + } + return web::ng::Response{boost::beast::http::status::service_unavailable, error, request}; + } + + static boost::json::object + makeLoadWarning(Response const& response) + { + auto jsonResponse = boost::json::parse(response.message()).as_object(); + jsonResponse["warning"] = "load"; + if (jsonResponse.contains("warnings") && jsonResponse["warnings"].is_array()) { + jsonResponse["warnings"].as_array().push_back(rpc::makeWarning(rpc::WarnRpcRateLimit)); + } else { + jsonResponse["warnings"] = boost::json::array{rpc::makeWarning(rpc::WarnRpcRateLimit)}; + } + return jsonResponse; + } + bool shouldReplaceParams(boost::json::object const& req) const { diff --git a/src/web/ng/impl/ConnectionHandler.cpp b/src/web/ng/impl/ConnectionHandler.cpp index adfd8bc1..2bcbfb65 100644 --- a/src/web/ng/impl/ConnectionHandler.cpp +++ b/src/web/ng/impl/ConnectionHandler.cpp @@ -86,24 +86,6 @@ handleWsRequest( } // namespace -size_t -ConnectionHandler::StringHash::operator()(char const* str) const -{ - return hash_type{}(str); -} - -size_t -ConnectionHandler::StringHash::operator()(std::string_view str) const -{ - return hash_type{}(str); -} - -size_t -ConnectionHandler::StringHash::operator()(std::string const& str) const -{ - return hash_type{}(str); -} - ConnectionHandler::ConnectionHandler( ProcessingPolicy processingPolicy, std::optional maxParallelRequests, diff --git a/src/web/ng/impl/ConnectionHandler.hpp b/src/web/ng/impl/ConnectionHandler.hpp index cc1fc641..e21b07bd 100644 --- a/src/web/ng/impl/ConnectionHandler.hpp +++ b/src/web/ng/impl/ConnectionHandler.hpp @@ -20,6 +20,7 @@ #pragma once #include "util/StopHelper.hpp" +#include "util/StringHash.hpp" #include "util/Taggable.hpp" #include "util/log/Logger.hpp" #include "util/prometheus/Gauge.hpp" @@ -44,7 +45,6 @@ #include #include #include -#include #include namespace web::ng::impl { @@ -52,20 +52,7 @@ namespace web::ng::impl { class ConnectionHandler { public: using OnDisconnectHook = std::function; - - struct StringHash { - using hash_type = std::hash; - using is_transparent = void; - - std::size_t - operator()(char const* str) const; - std::size_t - operator()(std::string_view str) const; - std::size_t - operator()(std::string const& str) const; - }; - - using TargetToHandlerMap = std::unordered_map>; + using TargetToHandlerMap = std::unordered_map>; private: util::Logger log_{"WebServer"}; diff --git a/tests/common/web/dosguard/DOSGuardMock.hpp b/tests/common/web/dosguard/DOSGuardMock.hpp index a465d8b7..89843901 100644 --- a/tests/common/web/dosguard/DOSGuardMock.hpp +++ b/tests/common/web/dosguard/DOSGuardMock.hpp @@ -21,6 +21,7 @@ #include "web/dosguard/DOSGuardInterface.hpp" +#include #include #include @@ -33,7 +34,7 @@ struct DOSGuardMockImpl : web::dosguard::DOSGuardInterface { MOCK_METHOD(void, increment, (std::string const& ip), (noexcept, override)); MOCK_METHOD(void, decrement, (std::string const& ip), (noexcept, override)); MOCK_METHOD(bool, add, (std::string const& ip, uint32_t size), (noexcept, override)); - MOCK_METHOD(bool, request, (std::string const& ip), (noexcept, override)); + MOCK_METHOD(bool, request, (std::string const& ip, boost::json::object const& request), (override)); MOCK_METHOD(void, clear, (), (noexcept, override)); }; diff --git a/tests/common/web/interface/ConnectionBaseMock.hpp b/tests/common/web/interface/ConnectionBaseMock.hpp index e3885bc5..9a2f0c13 100644 --- a/tests/common/web/interface/ConnectionBaseMock.hpp +++ b/tests/common/web/interface/ConnectionBaseMock.hpp @@ -40,6 +40,7 @@ struct ConnectionBaseMock : web::ConnectionBase { (util::TagDecoratorFactory const& factory), (override) ); + MOCK_METHOD(void, sendSlowDown, (std::string const&), (override)); }; using ConnectionBaseStrictMockPtr = std::shared_ptr>; diff --git a/tests/unit/CMakeLists.txt b/tests/unit/CMakeLists.txt index 5985229a..3484b356 100644 --- a/tests/unit/CMakeLists.txt +++ b/tests/unit/CMakeLists.txt @@ -90,6 +90,7 @@ target_sources( rpc/common/CheckersTests.cpp rpc/common/SpecsTests.cpp rpc/common/TypesTests.cpp + rpc/common/impl/HandlerProviderTests.cpp rpc/handlers/AccountChannelsTests.cpp rpc/handlers/AccountCurrenciesTests.cpp rpc/handlers/AccountInfoTests.cpp @@ -146,6 +147,7 @@ target_sources( util/CoroutineGroupTests.cpp util/LedgerUtilsTests.cpp util/StrandedPriorityQueueTests.cpp + util/StringHashTests.cpp # Prometheus support util/prometheus/BoolTests.cpp util/prometheus/CounterTests.cpp @@ -177,6 +179,7 @@ target_sources( web/AdminVerificationTests.cpp web/dosguard/DOSGuardTests.cpp web/dosguard/IntervalSweepHandlerTests.cpp + web/dosguard/WeightsTests.cpp web/dosguard/WhitelistHandlerTests.cpp web/impl/ErrorHandlingTests.cpp web/ng/ResponseTests.cpp @@ -196,6 +199,7 @@ target_sources( util/newconfig/ArrayTests.cpp util/newconfig/ArrayViewTests.cpp util/newconfig/ClioConfigDefinitionTests.cpp + util/newconfig/ConfigConstraintsTests.cpp util/newconfig/ConfigValueTests.cpp util/newconfig/ObjectViewTests.cpp util/newconfig/ConfigFileJsonTests.cpp diff --git a/tests/unit/app/WebHandlersTests.cpp b/tests/unit/app/WebHandlersTests.cpp index 030cc846..b90345b2 100644 --- a/tests/unit/app/WebHandlersTests.cpp +++ b/tests/unit/app/WebHandlersTests.cpp @@ -165,75 +165,13 @@ struct RequestHandlerTest : SyncAsioContextTest, WebHandlersTest { testing::StrictMock rpcHandler; StrictMockConnection connectionMock{ip, boost::beast::flat_buffer{}, tagFactory}; - RequestHandler requestHandler{adminVerifier, rpcHandler, dosGuardMock}; + RequestHandler requestHandler{adminVerifier, rpcHandler}; }; -TEST_F(RequestHandlerTest, DosguardRateLimited_Http) -{ - web::ng::Request const request{http::request{http::verb::get, "/", 11}}; - - EXPECT_CALL(dosGuardMock, request(ip)).WillOnce(testing::Return(false)); - - runSpawn([&](boost::asio::yield_context yield) { - auto response = requestHandler(request, connectionMock, nullptr, yield); - auto const httpResponse = std::move(response).intoHttpResponse(); - - EXPECT_EQ(httpResponse.result(), boost::beast::http::status::service_unavailable); - - auto const body = boost::json::parse(httpResponse.body()).as_object(); - EXPECT_EQ(body.at("error").as_string(), "slowDown"); - EXPECT_EQ(body.at("error_code").as_int64(), rpc::RippledError::rpcSLOW_DOWN); - EXPECT_EQ(body.at("status").as_string(), "error"); - EXPECT_FALSE(body.contains("id")); - EXPECT_FALSE(body.contains("request")); - }); -} - -TEST_F(RequestHandlerTest, DosguardRateLimited_Ws) -{ - auto const requestMessage = R"json({"some": "request", "id": "some id"})json"; - web::ng::Request::HttpHeaders const headers{}; - web::ng::Request const request{requestMessage, headers}; - - EXPECT_CALL(dosGuardMock, request(ip)).WillOnce(testing::Return(false)); - - runSpawn([&](boost::asio::yield_context yield) { - auto const response = requestHandler(request, connectionMock, nullptr, yield); - auto const message = boost::json::parse(response.message()).as_object(); - - EXPECT_EQ(message.at("error").as_string(), "slowDown"); - EXPECT_EQ(message.at("error_code").as_int64(), rpc::RippledError::rpcSLOW_DOWN); - EXPECT_EQ(message.at("status").as_string(), "error"); - EXPECT_EQ(message.at("id").as_string(), "some id"); - EXPECT_EQ(message.at("request").as_string(), requestMessage); - }); -} - -TEST_F(RequestHandlerTest, DosguardRateLimited_Ws_ErrorParsing) -{ - auto const requestMessage = R"json(some request "id": "some id")json"; - web::ng::Request::HttpHeaders const headers{}; - web::ng::Request const request{requestMessage, headers}; - - EXPECT_CALL(dosGuardMock, request(ip)).WillOnce(testing::Return(false)); - - runSpawn([&](boost::asio::yield_context yield) { - auto const response = requestHandler(request, connectionMock, nullptr, yield); - auto const message = boost::json::parse(response.message()).as_object(); - - EXPECT_EQ(message.at("error").as_string(), "slowDown"); - EXPECT_EQ(message.at("error_code").as_int64(), rpc::RippledError::rpcSLOW_DOWN); - EXPECT_EQ(message.at("status").as_string(), "error"); - EXPECT_FALSE(message.contains("id")); - EXPECT_EQ(message.at("request").as_string(), requestMessage); - }); -} - TEST_F(RequestHandlerTest, RpcHandlerThrows) { web::ng::Request const request{http::request{http::verb::get, "/", 11}}; - EXPECT_CALL(dosGuardMock, request(ip)).WillOnce(testing::Return(true)); EXPECT_CALL(*adminVerifier, isAdmin).WillOnce(testing::Return(true)); EXPECT_CALL(rpcHandler, call).WillOnce(testing::Throw(std::runtime_error{"some error"})); @@ -257,10 +195,8 @@ TEST_F(RequestHandlerTest, NoErrors) web::ng::Response const response{http::status::ok, "some response", request}; auto const httpResponse = web::ng::Response{response}.intoHttpResponse(); - EXPECT_CALL(dosGuardMock, request(ip)).WillOnce(testing::Return(true)); EXPECT_CALL(*adminVerifier, isAdmin).WillOnce(testing::Return(true)); EXPECT_CALL(rpcHandler, call).WillOnce(testing::Return(response)); - EXPECT_CALL(dosGuardMock, add(ip, testing::_)).WillOnce(testing::Return(true)); runSpawn([&](boost::asio::yield_context yield) { auto actualResponse = requestHandler(request, connectionMock, nullptr, yield); @@ -272,55 +208,3 @@ TEST_F(RequestHandlerTest, NoErrors) EXPECT_EQ(actualHttpResponse.version(), 11); }); } - -TEST_F(RequestHandlerTest, ResponseDosGuardWarning_ResponseHasWarnings) -{ - web::ng::Request const request{http::request{http::verb::get, "/", 11}}; - web::ng::Response const response{ - http::status::ok, R"json({"some":"response", "warnings":["some warning"]})json", request - }; - auto const httpResponse = web::ng::Response{response}.intoHttpResponse(); - - EXPECT_CALL(dosGuardMock, request(ip)).WillOnce(testing::Return(true)); - EXPECT_CALL(*adminVerifier, isAdmin).WillOnce(testing::Return(true)); - EXPECT_CALL(rpcHandler, call).WillOnce(testing::Return(response)); - EXPECT_CALL(dosGuardMock, add(ip, testing::_)).WillOnce(testing::Return(false)); - - runSpawn([&](boost::asio::yield_context yield) { - auto actualResponse = requestHandler(request, connectionMock, nullptr, yield); - - auto const actualHttpResponse = std::move(actualResponse).intoHttpResponse(); - - EXPECT_EQ(actualHttpResponse.result(), httpResponse.result()); - EXPECT_EQ(actualHttpResponse.version(), 11); - - auto actualBody = boost::json::parse(actualHttpResponse.body()).as_object(); - EXPECT_EQ(actualBody.at("some").as_string(), "response"); - EXPECT_EQ(actualBody.at("warnings").as_array().size(), 2); - }); -} - -TEST_F(RequestHandlerTest, ResponseDosGuardWarning_ResponseDoesntHaveWarnings) -{ - web::ng::Request const request{http::request{http::verb::get, "/", 11}}; - web::ng::Response const response{http::status::ok, R"json({"some":"response"})json", request}; - auto const httpResponse = web::ng::Response{response}.intoHttpResponse(); - - EXPECT_CALL(dosGuardMock, request(ip)).WillOnce(testing::Return(true)); - EXPECT_CALL(*adminVerifier, isAdmin).WillOnce(testing::Return(true)); - EXPECT_CALL(rpcHandler, call).WillOnce(testing::Return(response)); - EXPECT_CALL(dosGuardMock, add(ip, testing::_)).WillOnce(testing::Return(false)); - - runSpawn([&](boost::asio::yield_context yield) { - auto actualResponse = requestHandler(request, connectionMock, nullptr, yield); - - auto const actualHttpResponse = std::move(actualResponse).intoHttpResponse(); - - EXPECT_EQ(actualHttpResponse.result(), httpResponse.result()); - EXPECT_EQ(actualHttpResponse.version(), 11); - - auto actualBody = boost::json::parse(actualHttpResponse.body()).as_object(); - EXPECT_EQ(actualBody.at("some").as_string(), "response"); - EXPECT_EQ(actualBody.at("warnings").as_array().size(), 1); - }); -} diff --git a/tests/unit/rpc/RPCEngineTests.cpp b/tests/unit/rpc/RPCEngineTests.cpp index 039eef2c..34094326 100644 --- a/tests/unit/rpc/RPCEngineTests.cpp +++ b/tests/unit/rpc/RPCEngineTests.cpp @@ -41,6 +41,7 @@ #include "util/newconfig/Types.hpp" #include "web/Context.hpp" #include "web/dosguard/DOSGuard.hpp" +#include "web/dosguard/Weights.hpp" #include "web/dosguard/WhitelistHandler.hpp" #include @@ -100,7 +101,8 @@ struct RPCEngineTest : util::prometheus::WithPrometheus, util::TagDecoratorFactory tagFactory{cfg}; WorkQueue queue = WorkQueue::makeWorkQueue(cfg); web::dosguard::WhitelistHandler whitelistHandler{cfg}; - web::dosguard::DOSGuard dosGuard{cfg, whitelistHandler}; + web::dosguard::Weights weights{1, {}}; + web::dosguard::DOSGuard dosGuard{cfg, whitelistHandler, weights}; std::shared_ptr handlerProvider = std::make_shared(); }; diff --git a/tests/unit/rpc/common/impl/HandlerProviderTests.cpp b/tests/unit/rpc/common/impl/HandlerProviderTests.cpp new file mode 100644 index 00000000..4642503c --- /dev/null +++ b/tests/unit/rpc/common/impl/HandlerProviderTests.cpp @@ -0,0 +1,74 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "rpc/Counters.hpp" +#include "rpc/RPCCenter.hpp" +#include "rpc/WorkQueue.hpp" +#include "rpc/common/APIVersion.hpp" +#include "rpc/common/impl/HandlerProvider.hpp" +#include "util/MockAmendmentCenter.hpp" +#include "util/MockBackendTestFixture.hpp" +#include "util/MockETLService.hpp" +#include "util/MockLoadBalancer.hpp" +#include "util/MockPrometheus.hpp" +#include "util/MockSubscriptionManager.hpp" +#include "util/newconfig/ConfigDefinition.hpp" +#include "util/newconfig/ConfigValue.hpp" +#include "util/newconfig/Types.hpp" + +#include +#include + +#include + +using namespace rpc; + +struct ProductionHandlerProviderTest : util::prometheus::WithPrometheus, MockBackendTestStrict { + util::config::ClioConfigDefinition config{ + {"api_version.default", + util::config::ConfigValue{util::config::ConfigType::Integer}.defaultValue(rpc::kAPI_VERSION_DEFAULT)}, + {"api_version.min", + util::config::ConfigValue{util::config::ConfigType::Integer}.defaultValue(rpc::kAPI_VERSION_MIN)}, + {"api_version.max", + util::config::ConfigValue{util::config::ConfigType::Integer}.defaultValue(rpc::kAPI_VERSION_MAX)}, + }; + StrictMockSubscriptionManagerSharedPtr subscriptionManagerMock; + std::shared_ptr> loadBalancerMock; + std::shared_ptr> etlServiceMock; + StrictMockAmendmentCenterSharedPtr mockAmendmentCenterPtr; + WorkQueue workQueue{1}; + Counters counters{workQueue}; + + impl::ProductionHandlerProvider handlerProvider{ + config, + backend_, + subscriptionManagerMock, + loadBalancerMock, + etlServiceMock, + mockAmendmentCenterPtr, + counters + }; +}; + +TEST_F(ProductionHandlerProviderTest, HandlersListIsComplete) +{ + auto const handlerNames = handlerProvider.handlerNames(); + for (auto const& name : handlerNames) + EXPECT_TRUE(RPCCenter::isHandled(name)); +} diff --git a/tests/unit/util/StringHashTests.cpp b/tests/unit/util/StringHashTests.cpp new file mode 100644 index 00000000..af16efc8 --- /dev/null +++ b/tests/unit/util/StringHashTests.cpp @@ -0,0 +1,70 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "util/StringHash.hpp" + +#include + +#include +#include +#include +#include + +using namespace util; + +TEST(StringHashTest, HashesConsistently) +{ + StringHash hasher; + + std::string const stdString = "test string"; + std::string_view const strView = "test string"; + char const* cString = "test string"; + + EXPECT_EQ(hasher(stdString), hasher(strView)); + EXPECT_EQ(hasher(stdString), hasher(cString)); + EXPECT_EQ(hasher(strView), hasher(cString)); +} + +TEST(StringHashTest, TransparentLookup) +{ + std::unordered_set> stringSet{"hello world"}; + + std::string const stdString = "hello world"; + std::string_view const strView = "hello world"; + char const* cString = "hello world"; + + EXPECT_TRUE(stringSet.contains(stdString)); + EXPECT_TRUE(stringSet.contains(strView)); + EXPECT_TRUE(stringSet.contains(cString)); + + EXPECT_FALSE(stringSet.contains("goodbye world")); +} + +TEST(StringHashTest, EmptyStrings) +{ + StringHash hasher; + + std::string const emptyStdString; + std::string_view const emptyStrView; + char const* emptyCString = ""; + + EXPECT_EQ(hasher(emptyStdString), hasher(emptyStrView)); + EXPECT_EQ(hasher(emptyStdString), hasher(emptyCString)); + EXPECT_EQ(hasher(emptyStrView), hasher(emptyCString)); +} diff --git a/tests/unit/util/newconfig/ConfigConstraintsTests.cpp b/tests/unit/util/newconfig/ConfigConstraintsTests.cpp new file mode 100644 index 00000000..911ac29d --- /dev/null +++ b/tests/unit/util/newconfig/ConfigConstraintsTests.cpp @@ -0,0 +1,50 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "util/newconfig/ConfigConstraints.hpp" +#include "util/newconfig/Types.hpp" + +#include + +using namespace util::config; + +struct RpcNameConstraintTest : testing::Test {}; + +TEST_F(RpcNameConstraintTest, WrongType) +{ + Value value{1}; + auto const maybeError = gRpcNameConstraint.checkConstraint(value); + ASSERT_TRUE(maybeError.has_value()); + EXPECT_EQ(maybeError->error, "RPC command name must be a string"); +} + +TEST_F(RpcNameConstraintTest, WrongValue) +{ + Value value{"non_existing_rpc"}; + auto const maybeError = gRpcNameConstraint.checkConstraint(value); + ASSERT_TRUE(maybeError.has_value()); + EXPECT_EQ(maybeError->error, "Invalid RPC command name"); +} + +TEST_F(RpcNameConstraintTest, CorrectValue) +{ + Value value{"server_info"}; + auto const maybeError = gRpcNameConstraint.checkConstraint(value); + ASSERT_FALSE(maybeError.has_value()); +} diff --git a/tests/unit/web/RPCServerHandlerTests.cpp b/tests/unit/web/RPCServerHandlerTests.cpp index 5f2a11c7..f85bbf82 100644 --- a/tests/unit/web/RPCServerHandlerTests.cpp +++ b/tests/unit/web/RPCServerHandlerTests.cpp @@ -31,6 +31,7 @@ #include "util/newconfig/Types.hpp" #include "web/RPCServerHandler.hpp" #include "web/SubscriptionContextInterface.hpp" +#include "web/dosguard/DOSGuardMock.hpp" #include "web/interface/ConnectionBase.hpp" #include @@ -39,6 +40,7 @@ #include #include +#include #include #include #include @@ -58,6 +60,7 @@ constexpr auto kMAX_SEQ = 30; struct MockWsBase : public web::ConnectionBase { std::string message; boost::beast::http::status lastStatus = boost::beast::http::status::unknown; + size_t slowDownCallsCounter{0}; void send(std::shared_ptr msgType) override @@ -74,6 +77,12 @@ struct MockWsBase : public web::ConnectionBase { lastStatus = status; } + void + sendSlowDown(std::string const&) override + { + ++slowDownCallsCounter; + } + SubscriptionContextPtr makeSubscriptionContext(util::TagDecoratorFactory const&) override { @@ -94,9 +103,10 @@ struct WebRPCServerHandlerTest : util::prometheus::WithPrometheus, MockBackendTe }; std::shared_ptr rpcEngine = std::make_shared(); std::shared_ptr etl = std::make_shared(); + DOSGuardStrictMock dosguard; std::shared_ptr tagFactory = std::make_shared(cfg); std::shared_ptr> handler = - std::make_shared>(cfg, backend_, rpcEngine, etl); + std::make_shared>(cfg, backend_, rpcEngine, etl, dosguard); std::shared_ptr session = std::make_shared(*tagFactory); }; @@ -121,6 +131,11 @@ TEST_F(WebRPCServerHandlerTest, HTTPDefaultPath) } ] })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) .WillOnce(testing::Return(rpc::Result{boost::json::parse(kRESULT).as_object()})); EXPECT_CALL(*rpcEngine, notifyComplete("server_info", testing::_)).Times(1); @@ -131,6 +146,33 @@ TEST_F(WebRPCServerHandlerTest, HTTPDefaultPath) EXPECT_EQ(boost::json::parse(session->message), boost::json::parse(kRESPONSE)); } +TEST_F(WebRPCServerHandlerTest, HTTPRejectedByDosguard) +{ + static constexpr auto kREQUEST = R"({ + "method": "server_info", + "params": [{}] + })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(false)); + + (*handler)(kREQUEST, session); + EXPECT_EQ(session->slowDownCallsCounter, 1); +} + +TEST_F(WebRPCServerHandlerTest, HTTPRejectedByDosguardAfterParsing) +{ + static constexpr auto kREQUEST = R"({ + "method": "server_info", + "params": [{}] + })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(false)); + + (*handler)(kREQUEST, session); + EXPECT_EQ(session->slowDownCallsCounter, 1); +} + TEST_F(WebRPCServerHandlerTest, WsNormalPath) { session->upgraded = true; @@ -156,6 +198,10 @@ TEST_F(WebRPCServerHandlerTest, WsNormalPath) } ] })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) .WillOnce(testing::Return(rpc::Result{boost::json::parse(kRESULT).as_object()})); EXPECT_CALL(*rpcEngine, notifyComplete("server_info", testing::_)).Times(1); @@ -166,6 +212,38 @@ TEST_F(WebRPCServerHandlerTest, WsNormalPath) EXPECT_EQ(boost::json::parse(session->message), boost::json::parse(kRESPONSE)); } +TEST_F(WebRPCServerHandlerTest, WsRejectedByDosguard) +{ + session->upgraded = true; + static constexpr auto kREQUEST = R"({ + "command": "server_info", + "id": 99, + "api_version": 2 + })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(false)); + + (*handler)(kREQUEST, session); + EXPECT_EQ(session->slowDownCallsCounter, 1); +} + +TEST_F(WebRPCServerHandlerTest, WsRejectedByDosguardAfterParsing) +{ + session->upgraded = true; + static constexpr auto kREQUEST = R"({ + "command": "server_info", + "id": 99, + "api_version": 2 + })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(false)); + + (*handler)(kREQUEST, session); + EXPECT_EQ(session->slowDownCallsCounter, 1); +} + TEST_F(WebRPCServerHandlerTest, HTTPForwardedPath) { static constexpr auto kREQUEST = R"({ @@ -195,6 +273,11 @@ TEST_F(WebRPCServerHandlerTest, HTTPForwardedPath) } ] })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) .WillOnce(testing::Return(rpc::Result{boost::json::parse(kRESULT).as_object()})); EXPECT_CALL(*rpcEngine, notifyComplete("server_info", testing::_)).Times(1); @@ -239,6 +322,11 @@ TEST_F(WebRPCServerHandlerTest, HTTPForwardedErrorPath) } ] })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) .WillOnce(testing::Return(rpc::Result{boost::json::parse(kRESULT).as_object()})); EXPECT_CALL(*rpcEngine, notifyComplete("server_info", testing::_)).Times(1); @@ -281,6 +369,11 @@ TEST_F(WebRPCServerHandlerTest, WsForwardedPath) } ] })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) .WillOnce(testing::Return(rpc::Result{boost::json::parse(kRESULT).as_object()})); EXPECT_CALL(*rpcEngine, notifyComplete("server_info", testing::_)).Times(1); @@ -326,6 +419,11 @@ TEST_F(WebRPCServerHandlerTest, WsForwardedErrorPath) } ] })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) .WillOnce(testing::Return(rpc::Result{boost::json::parse(kRESULT).as_object()})); @@ -373,6 +471,11 @@ TEST_F(WebRPCServerHandlerTest, HTTPErrorPath) } ] })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) .WillOnce(testing::Return(rpc::Result{rpc::Status{rpc::RippledError::rpcINVALID_PARAMS, "ledgerIndexMalformed"}} )); @@ -416,6 +519,11 @@ TEST_F(WebRPCServerHandlerTest, WsErrorPath) "id": "123", "api_version": 2 })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) .WillOnce(testing::Return(rpc::Result{rpc::Status{rpc::RippledError::rpcINVALID_PARAMS, "ledgerIndexMalformed"}} )); @@ -447,6 +555,10 @@ TEST_F(WebRPCServerHandlerTest, HTTPNotReady) } })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyNotReady).Times(1); (*handler)(kREQUEST, session); @@ -475,6 +587,10 @@ TEST_F(WebRPCServerHandlerTest, WsNotReady) } })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyNotReady).Times(1); (*handler)(kREQUEST, session); @@ -501,6 +617,9 @@ TEST_F(WebRPCServerHandlerTest, HTTPBadSyntaxWhenRequestSubscribe) } })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); (*handler)(kREQUEST, session); @@ -515,6 +634,9 @@ TEST_F(WebRPCServerHandlerTest, HTTPMissingCommand) static constexpr auto kRESPONSE = "Null method"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); (*handler)(kREQUEST, session); @@ -530,6 +652,9 @@ TEST_F(WebRPCServerHandlerTest, HTTPCommandNotString) static constexpr auto kRESPONSE = "method is not string"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); (*handler)(kREQUEST, session); @@ -545,6 +670,9 @@ TEST_F(WebRPCServerHandlerTest, HTTPCommandIsEmpty) static constexpr auto kRESPONSE = "method is empty"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); (*handler)(kREQUEST, session); @@ -575,6 +703,10 @@ TEST_F(WebRPCServerHandlerTest, WsMissingCommand) } })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); (*handler)(kREQUEST, session); @@ -592,6 +724,9 @@ TEST_F(WebRPCServerHandlerTest, HTTPParamsUnparsableNotArray) "params": "wrong" })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); (*handler)(kREQUEST_JSON, session); @@ -610,6 +745,9 @@ TEST_F(WebRPCServerHandlerTest, HTTPParamsUnparsableArrayWithDigit) "params": [1] })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); (*handler)(kREQUEST_JSON, session); @@ -640,6 +778,10 @@ TEST_F(WebRPCServerHandlerTest, HTTPInternalError) "params": [{}] })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyInternalError).Times(1); EXPECT_CALL(*rpcEngine, buildResponse(testing::_)).Times(1).WillOnce(testing::Throw(std::runtime_error("MyError"))); @@ -671,6 +813,10 @@ TEST_F(WebRPCServerHandlerTest, WsInternalError) "id": "123" })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyInternalError).Times(1); EXPECT_CALL(*rpcEngine, buildResponse(testing::_)).Times(1).WillOnce(testing::Throw(std::runtime_error("MyError"))); @@ -703,6 +849,11 @@ TEST_F(WebRPCServerHandlerTest, HTTPOutDated) } ] })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) .WillOnce(testing::Return(rpc::Result{boost::json::parse(kRESULT).as_object()})); EXPECT_CALL(*rpcEngine, notifyComplete("server_info", testing::_)).Times(1); @@ -741,6 +892,11 @@ TEST_F(WebRPCServerHandlerTest, WsOutdated) } ] })"; + + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) .WillOnce(testing::Return(rpc::Result{boost::json::parse(kRESULT).as_object()})); EXPECT_CALL(*rpcEngine, notifyComplete("server_info", testing::_)).Times(1); @@ -756,7 +912,7 @@ TEST_F(WebRPCServerHandlerTest, WsTooBusy) session->upgraded = true; auto localRpcEngine = std::make_shared(); - auto localHandler = std::make_shared>(cfg, backend_, localRpcEngine, etl); + auto localHandler = std::make_shared>(cfg, backend_, localRpcEngine, etl, dosguard); static constexpr auto kREQUEST = R"({ "command": "server_info", "id": 99 @@ -773,6 +929,10 @@ TEST_F(WebRPCServerHandlerTest, WsTooBusy) "type": "response" })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*localRpcEngine, notifyTooBusy).Times(1); EXPECT_CALL(*localRpcEngine, post).WillOnce(testing::Return(false)); @@ -783,7 +943,7 @@ TEST_F(WebRPCServerHandlerTest, WsTooBusy) TEST_F(WebRPCServerHandlerTest, HTTPTooBusy) { auto localRpcEngine = std::make_shared(); - auto localHandler = std::make_shared>(cfg, backend_, localRpcEngine, etl); + auto localHandler = std::make_shared>(cfg, backend_, localRpcEngine, etl, dosguard); static constexpr auto kREQUEST = R"({ "method": "server_info", "params": [{}] @@ -800,6 +960,10 @@ TEST_F(WebRPCServerHandlerTest, HTTPTooBusy) "type": "response" })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*localRpcEngine, notifyTooBusy).Times(1); EXPECT_CALL(*localRpcEngine, post).WillOnce(testing::Return(false)); @@ -812,6 +976,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPRequestNotJson) static constexpr auto kREQUEST = "not json"; static constexpr auto kRESPONSE_PREFIX = "Unable to parse JSON from the request"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); (*handler)(kREQUEST, session); @@ -832,6 +998,8 @@ TEST_F(WebRPCServerHandlerTest, WsRequestNotJson) "type": "response" })"; + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); (*handler)(kREQUEST, session); @@ -888,6 +1056,10 @@ TEST_P(WebRPCServerHandlerInvalidAPIVersionParamTest, HTTPInvalidAPIVersion) backend_->setRange(kMIN_SEQ, kMAX_SEQ); + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(request).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); (*handler)(request, session); @@ -908,6 +1080,10 @@ TEST_P(WebRPCServerHandlerInvalidAPIVersionParamTest, WSInvalidAPIVersion) backend_->setRange(kMIN_SEQ, kMAX_SEQ); + EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); + EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(request).as_object())) + .WillOnce(testing::Return(true)); + EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); (*handler)(request, session); diff --git a/tests/unit/web/ServerTests.cpp b/tests/unit/web/ServerTests.cpp index 1c310b38..149a338a 100644 --- a/tests/unit/web/ServerTests.cpp +++ b/tests/unit/web/ServerTests.cpp @@ -35,6 +35,7 @@ #include "web/dosguard/DOSGuard.hpp" #include "web/dosguard/DOSGuardInterface.hpp" #include "web/dosguard/IntervalSweepHandler.hpp" +#include "web/dosguard/Weights.hpp" #include "web/dosguard/WhitelistHandler.hpp" #include "web/interface/ConnectionBase.hpp" @@ -162,12 +163,13 @@ struct WebServerTest : NoLoggerFixture { std::string const port = std::to_string(tests::util::generateFreePort()); ClioConfigDefinition cfg{getParseServerConfig(generateJSONWithDynamicPort(port))}; dosguard::WhitelistHandler whitelistHandler{cfg}; - dosguard::DOSGuard dosGuard{cfg, whitelistHandler}; + dosguard::Weights dosguardWeights{1, {}}; + dosguard::DOSGuard dosGuard{cfg, whitelistHandler, dosguardWeights}; dosguard::IntervalSweepHandler sweepHandler{cfg, ctxSync, dosGuard}; ClioConfigDefinition cfgOverload{getParseServerConfig(generateJSONDataOverload(port))}; dosguard::WhitelistHandler whitelistHandlerOverload{cfgOverload}; - dosguard::DOSGuard dosGuardOverload{cfgOverload, whitelistHandlerOverload}; + dosguard::DOSGuard dosGuardOverload{cfgOverload, whitelistHandlerOverload, dosguardWeights}; dosguard::IntervalSweepHandler sweepHandlerOverload{cfgOverload, ctxSync, dosGuardOverload}; // this ctx is for http server boost::asio::io_context ctx; @@ -344,41 +346,6 @@ TEST_F(WebServerTest, Wss) wsClient.disconnect(); } -TEST_F(WebServerTest, HttpRequestOverload) -{ - auto const e = std::make_shared(); - auto const server = makeServerSync(cfg, ctx, dosGuardOverload, e); - auto [status, res] = HttpSyncClient::post("localhost", port, R"({})"); - EXPECT_EQ(res, "{}"); - EXPECT_EQ(status, boost::beast::http::status::ok); - - std::tie(status, res) = HttpSyncClient::post("localhost", port, R"({})"); - EXPECT_EQ( - res, - R"({"error":"slowDown","error_code":10,"error_message":"You are placing too much load on the server.","status":"error","type":"response"})" - ); - EXPECT_EQ(status, boost::beast::http::status::service_unavailable); -} - -TEST_F(WebServerTest, WsRequestOverload) -{ - auto e = std::make_shared(); - auto const server = makeServerSync(cfg, ctx, dosGuardOverload, e); - WebSocketSyncClient wsClient; - wsClient.connect("localhost", port); - auto res = wsClient.syncPost(R"({})"); - wsClient.disconnect(); - EXPECT_EQ(res, "{}"); - WebSocketSyncClient wsClient2; - wsClient2.connect("localhost", port); - res = wsClient2.syncPost(R"({})"); - wsClient2.disconnect(); - EXPECT_EQ( - res, - R"({"error":"slowDown","error_code":10,"error_message":"You are placing too much load on the server.","status":"error","type":"response","request":{}})" - ); -} - TEST_F(WebServerTest, HttpPayloadOverload) { std::string const s100(100, 'a'); diff --git a/tests/unit/web/dosguard/DOSGuardTests.cpp b/tests/unit/web/dosguard/DOSGuardTests.cpp index f2ace608..05827506 100644 --- a/tests/unit/web/dosguard/DOSGuardTests.cpp +++ b/tests/unit/web/dosguard/DOSGuardTests.cpp @@ -23,11 +23,15 @@ #include "util/newconfig/ConfigValue.hpp" #include "util/newconfig/Types.hpp" #include "web/dosguard/DOSGuard.hpp" +#include "web/dosguard/WeightsInterface.hpp" #include "web/dosguard/WhitelistHandlerInterface.hpp" +#include #include #include +#include +#include #include using namespace testing; @@ -53,6 +57,9 @@ struct DOSGuardTest : NoLoggerFixture { struct MockWhitelistHandler : WhitelistHandlerInterface { MOCK_METHOD(bool, isWhiteListed, (std::string_view ip), (const)); }; + struct MockWeights : WeightsInterface { + MOCK_METHOD(size_t, requestWeight, (boost::json::object const& cmd), (const, override)); + }; ClioConfigDefinition cfg{ {{"dos_guard.max_fetches", ConfigValue{ConfigType::Integer}.defaultValue(100)}, @@ -61,7 +68,9 @@ struct DOSGuardTest : NoLoggerFixture { {"dos_guard.whitelist", Array{ConfigValue{ConfigType::String}}}} }; NiceMock whitelistHandler; - DOSGuard guard{cfg, whitelistHandler}; + StrictMock weightsMock; + DOSGuard guard{cfg, whitelistHandler, weightsMock}; + boost::json::object const request; }; TEST_F(DOSGuardTest, Whitelisting) @@ -110,11 +119,20 @@ TEST_F(DOSGuardTest, ClearFetchCountOnTimer) TEST_F(DOSGuardTest, RequestLimit) { - EXPECT_TRUE(guard.request(kIP)); - EXPECT_TRUE(guard.request(kIP)); - EXPECT_TRUE(guard.request(kIP)); + EXPECT_CALL(weightsMock, requestWeight(request)).WillOnce(Return(1)); + EXPECT_TRUE(guard.request(kIP, request)); + + EXPECT_CALL(weightsMock, requestWeight(request)).WillOnce(Return(1)); + EXPECT_TRUE(guard.request(kIP, request)); + + EXPECT_CALL(weightsMock, requestWeight(request)).WillOnce(Return(1)); + EXPECT_TRUE(guard.request(kIP, request)); + EXPECT_TRUE(guard.isOk(kIP)); - EXPECT_FALSE(guard.request(kIP)); + + EXPECT_CALL(weightsMock, requestWeight(request)).WillOnce(Return(1)); + EXPECT_FALSE(guard.request(kIP, request)); + EXPECT_FALSE(guard.isOk(kIP)); guard.clear(); EXPECT_TRUE(guard.isOk(kIP)); // can request again @@ -122,11 +140,20 @@ TEST_F(DOSGuardTest, RequestLimit) TEST_F(DOSGuardTest, RequestLimitOnTimer) { - EXPECT_TRUE(guard.request(kIP)); - EXPECT_TRUE(guard.request(kIP)); - EXPECT_TRUE(guard.request(kIP)); + EXPECT_CALL(weightsMock, requestWeight(request)).WillOnce(Return(1)); + EXPECT_TRUE(guard.request(kIP, request)); + + EXPECT_CALL(weightsMock, requestWeight(request)).WillOnce(Return(1)); + EXPECT_TRUE(guard.request(kIP, request)); + + EXPECT_CALL(weightsMock, requestWeight(request)).WillOnce(Return(1)); + EXPECT_TRUE(guard.request(kIP, request)); + EXPECT_TRUE(guard.isOk(kIP)); - EXPECT_FALSE(guard.request(kIP)); + + EXPECT_CALL(weightsMock, requestWeight(request)).WillOnce(Return(1)); + EXPECT_FALSE(guard.request(kIP, request)); + EXPECT_FALSE(guard.isOk(kIP)); guard.clear(); EXPECT_TRUE(guard.isOk(kIP)); // can request again diff --git a/tests/unit/web/dosguard/WeightsTests.cpp b/tests/unit/web/dosguard/WeightsTests.cpp new file mode 100644 index 00000000..2567c8af --- /dev/null +++ b/tests/unit/web/dosguard/WeightsTests.cpp @@ -0,0 +1,277 @@ +//------------------------------------------------------------------------------ +/* + This file is part of clio: https://github.com/XRPLF/clio + Copyright (c) 2025, the clio developers. + + Permission to use, copy, modify, and distribute this software for any + purpose with or without fee is hereby granted, provided that the above + copyright notice and this permission notice appear in all copies. + + THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR + ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN + ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF + OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. +*/ +//============================================================================== + +#include "util/newconfig/Array.hpp" +#include "util/newconfig/ConfigDefinition.hpp" +#include "util/newconfig/ConfigFileJson.hpp" +#include "util/newconfig/ConfigValue.hpp" +#include "util/newconfig/Types.hpp" +#include "web/dosguard/Weights.hpp" + +#include +#include +#include +#include + +#include +#include +#include +#include + +using namespace web::dosguard; + +struct TestParams { + std::string testName; + std::string requestJson; + size_t expectedWeight; +}; + +class WeightsTest : public ::testing::TestWithParam { +protected: + size_t const defaultWeight_{10}; + std::unordered_map weightsMap_{ + {"only_weight", {.weight = 20, .weightLedgerCurrent = std::nullopt, .weightLedgerValidated = std::nullopt}}, + {"with_current_weight", {.weight = 30, .weightLedgerCurrent = 35, .weightLedgerValidated = std::nullopt}}, + {"with_validated_weight", {.weight = 40, .weightLedgerCurrent = std::nullopt, .weightLedgerValidated = 45}}, + {"with_both_weights", {.weight = 50, .weightLedgerCurrent = 55, .weightLedgerValidated = 60}}, + }; + Weights weights_{defaultWeight_, weightsMap_}; +}; + +TEST_P(WeightsTest, RequestWeight) +{ + auto const& params = GetParam(); + auto request = boost::json::parse(params.requestJson).as_object(); + EXPECT_EQ(weights_.requestWeight(request), params.expectedWeight); +} + +INSTANTIATE_TEST_SUITE_P( + WeightsTests, + WeightsTest, + ::testing::Values( + TestParams{.testName = "EmptyObject", .requestJson = "{}", .expectedWeight = 10}, + TestParams{.testName = "NonStringMethod", .requestJson = R"json({"method": 123})json", .expectedWeight = 10}, + TestParams{.testName = "NonStringCommand", .requestJson = R"json({"command": 123})json", .expectedWeight = 10}, + + TestParams{ + .testName = "UnknownMethodName", + .requestJson = R"json({"method": "unknown_method"})json", + .expectedWeight = 10 + }, + TestParams{ + .testName = "UnknownCommandName", + .requestJson = R"json({"command": "unknown_command"})json", + .expectedWeight = 10 + }, + + TestParams{ + .testName = "OnlyWeight_NoLedgerIndex", + .requestJson = R"json({"method": "only_weight"})json", + .expectedWeight = 20 + }, + TestParams{ + .testName = "OnlyWeight_CurrentLedgerIndex", + .requestJson = R"json({"method": "only_weight", "ledger_index": "current"})json", + .expectedWeight = 20 + }, + TestParams{ + .testName = "OnlyWeight_ValidatedLedgerIndex", + .requestJson = R"json({"method": "only_weight", "ledger_index": "validated"})json", + .expectedWeight = 20 + }, + TestParams{ + .testName = "OnlyWeight_ClosedLedgerIndex", + .requestJson = R"json({"method": "only_weight", "ledger_index": "closed"})json", + .expectedWeight = 20 + }, + TestParams{ + .testName = "OnlyWeight_NumericLedgerIndex", + .requestJson = R"json({"method": "only_weight", "ledger_index": "123"})json", + .expectedWeight = 20 + }, + TestParams{ + .testName = "OnlyWeight_OtherStringLedgerIndex", + .requestJson = R"json({"method": "only_weight", "ledger_index": "some_string"})json", + .expectedWeight = 20 + }, + + // With Current Weight + TestParams{ + .testName = "WithCurrentWeight_NoLedgerIndex", + .requestJson = R"json({"method": "with_current_weight"})json", + .expectedWeight = 30 + }, + TestParams{ + .testName = "WithCurrentWeight_CurrentLedgerIndex", + .requestJson = R"json({"method": "with_current_weight", "ledger_index": "current"})json", + .expectedWeight = 35 + }, + TestParams{ + .testName = "WithCurrentWeight_ValidatedLedgerIndex", + .requestJson = R"json({"method": "with_current_weight", "ledger_index": "validated"})json", + .expectedWeight = 30 + }, + + // With Validated Weight + TestParams{ + .testName = "WithValidatedWeight_NoLedgerIndex", + .requestJson = R"json({"method": "with_validated_weight"})json", + .expectedWeight = 40 + }, + TestParams{ + .testName = "WithValidatedWeight_CurrentLedgerIndex", + .requestJson = R"json({"method": "with_validated_weight", "ledger_index": "current"})json", + .expectedWeight = 40 + }, + TestParams{ + .testName = "WithValidatedWeight_ValidatedLedgerIndex", + .requestJson = R"json({"method": "with_validated_weight", "ledger_index": "validated"})json", + .expectedWeight = 45 + }, + + // With Both Weights + TestParams{ + .testName = "WithBothWeights_NoLedgerIndex", + .requestJson = R"json({"method": "with_both_weights"})json", + .expectedWeight = 50 + }, + TestParams{ + .testName = "WithBothWeights_CurrentLedgerIndex", + .requestJson = R"json({"method": "with_both_weights", "ledger_index": "current"})json", + .expectedWeight = 55 + }, + TestParams{ + .testName = "WithBothWeights_ValidatedLedgerIndex", + .requestJson = R"json({"method": "with_both_weights", "ledger_index": "validated"})json", + .expectedWeight = 60 + }, + + // Using Command + TestParams{ + .testName = "UsingCommand_NoLedgerIndex", + .requestJson = R"json({"command": "with_both_weights"})json", + .expectedWeight = 50 + }, + TestParams{ + .testName = "UsingCommand_CurrentLedgerIndex", + .requestJson = R"json({"command": "with_both_weights", "ledger_index": "current"})json", + .expectedWeight = 55 + }, + TestParams{ + .testName = "UsingCommand_ValidatedLedgerIndex", + .requestJson = R"json({"command": "with_both_weights", "ledger_index": "validated"})json", + .expectedWeight = 60 + }, + + // With Params Array + TestParams{ + .testName = "WithParamsArray_CurrentLedgerIndex", + .requestJson = R"json({"method": "with_both_weights", "params": [{"ledger_index": "current"}]})json", + .expectedWeight = 55 + }, + TestParams{ + .testName = "WithParamsArray_ValidatedLedgerIndex", + .requestJson = R"json({"method": "with_both_weights", "params": [{"ledger_index": "validated"}]})json", + .expectedWeight = 60 + }, + TestParams{ + .testName = "WithParamsArray_WithCommand", + .requestJson = R"json({"command": "with_both_weights", "params": [{"ledger_index": "current"}]})json", + .expectedWeight = 55 + } + ), + [](::testing::TestParamInfo const& info) { return info.param.testName; } +); + +TEST(WeightsMakeTest, CreateFromConfig) +{ + util::config::ClioConfigDefinition mockConfig{ + {"dos_guard.__ng_default_weight", util::config::ConfigValue{util::config::ConfigType::Integer}.defaultValue(10) + }, + {"dos_guard.__ng_weights.[].method", + util::config::Array{util::config::ConfigValue{util::config::ConfigType::String}}}, + {"dos_guard.__ng_weights.[].weight", + util::config::Array{util::config::ConfigValue{util::config::ConfigType::Integer}}}, + {"dos_guard.__ng_weights.[].weight_ledger_current", + util::config::Array{util::config::ConfigValue{util::config::ConfigType::Integer}.optional()}}, + {"dos_guard.__ng_weights.[].weight_ledger_validated", + util::config::Array{util::config::ConfigValue{util::config::ConfigType::Integer}.optional()}} + }; + std::string const configStr = R"json( + { + "dos_guard": { + "__ng_default_weight": 15, + "__ng_weights": [ + { + "method": "method1", + "weight": 25, + "weight_ledger_current": 30 + }, + { + "method": "method2", + "weight": 35, + "weight_ledger_validated": 40 + }, + { + "method": "method3", + "weight": 45, + "weight_ledger_current": 50, + "weight_ledger_validated": 55 + } + ] + } + } + )json"; + + auto const configJson = boost::json::parse(configStr).as_object(); + + ASSERT_FALSE(mockConfig.parse(util::config::ConfigFileJson(configJson)).has_value()); + + Weights const weights = Weights::make(mockConfig); + + auto request = boost::json::parse(R"json({"method": "unknown_method"})json").as_object(); + EXPECT_EQ(weights.requestWeight(request), 15); + + request = boost::json::parse(R"json({"method": "method1"})json").as_object(); + EXPECT_EQ(weights.requestWeight(request), 25); + + request = boost::json::parse(R"json({"method": "method1", "ledger_index": "current"})json").as_object(); + EXPECT_EQ(weights.requestWeight(request), 30); + + request = boost::json::parse(R"json({"method": "method1", "ledger_index": "validated"})json").as_object(); + EXPECT_EQ(weights.requestWeight(request), 25); + + request = boost::json::parse(R"json({"method": "method2"})json").as_object(); + EXPECT_EQ(weights.requestWeight(request), 35); + + request = boost::json::parse(R"json({"method": "method2", "ledger_index": "current"})json").as_object(); + EXPECT_EQ(weights.requestWeight(request), 35); + + request = boost::json::parse(R"json({"method": "method2", "ledger_index": "validated"})json").as_object(); + EXPECT_EQ(weights.requestWeight(request), 40); + + request = boost::json::parse(R"json({"method": "method3"})json").as_object(); + EXPECT_EQ(weights.requestWeight(request), 45); + + request = boost::json::parse(R"json({"method": "method3", "ledger_index": "current"})json").as_object(); + EXPECT_EQ(weights.requestWeight(request), 50); + + request = boost::json::parse(R"json({"method": "method3", "ledger_index": "validated"})json").as_object(); + EXPECT_EQ(weights.requestWeight(request), 55); +} diff --git a/tests/unit/web/ng/RPCServerHandlerTests.cpp b/tests/unit/web/ng/RPCServerHandlerTests.cpp index abec3480..c3da8ed7 100644 --- a/tests/unit/web/ng/RPCServerHandlerTests.cpp +++ b/tests/unit/web/ng/RPCServerHandlerTests.cpp @@ -29,11 +29,13 @@ #include "util/newconfig/ConfigValue.hpp" #include "util/newconfig/Types.hpp" #include "web/SubscriptionContextInterface.hpp" +#include "web/dosguard/DOSGuardMock.hpp" #include "web/ng/MockConnection.hpp" #include "web/ng/RPCServerHandler.hpp" #include "web/ng/Request.hpp" #include +#include #include #include #include @@ -72,23 +74,84 @@ protected: std::shared_ptr> rpcEngine_ = std::make_shared>(); std::shared_ptr> etl_ = std::make_shared>(); - RPCServerHandler rpcServerHandler_{config, backend_, rpcEngine_, etl_}; + DOSGuardStrictMock dosguard_; + RPCServerHandler rpcServerHandler_{config, backend_, rpcEngine_, etl_, dosguard_}; util::TagDecoratorFactory tagFactory_{config}; - StrictMockConnectionMetadata connectionMetadata_{"some ip", tagFactory_}; + std::string const ip_ = "some ip"; + StrictMockConnectionMetadata connectionMetadata_{ip_, tagFactory_}; + Request::HttpHeaders const httpHeaders_; static Request makeHttpRequest(std::string_view body) { return Request{http::request{http::verb::post, "/", 11, body}}; } + + Request + makeWsRequest(std::string body) + { + return Request{std::move(body), httpHeaders_}; + } }; +TEST_F(NgRpcServerHandlerTest, DosguardRejectedHttpRequest) +{ + runSpawn([&](boost::asio::yield_context yield) { + auto const request = makeHttpRequest("some message"); + + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(false)); + auto response = rpcServerHandler_(request, connectionMetadata_, nullptr, yield); + + auto const responseHttp = std::move(response).intoHttpResponse(); + EXPECT_EQ(responseHttp.result(), http::status::service_unavailable); + + auto const responseJson = boost::json::parse(responseHttp.body()).as_object(); + EXPECT_EQ(responseJson.at("error_code").as_int64(), rpc::RippledError::rpcSLOW_DOWN); + }); +} + +TEST_F(NgRpcServerHandlerTest, DosguardRejectedWsRequest) +{ + runSpawn([&](boost::asio::yield_context yield) { + auto const requestStr = "some message"; + auto const request = makeWsRequest(requestStr); + + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(false)); + auto response = rpcServerHandler_(request, connectionMetadata_, nullptr, yield); + + auto const responseWs = boost::beast::buffers_to_string(response.asWsResponse()); + + auto const responseJson = boost::json::parse(responseWs).as_object(); + EXPECT_EQ(responseJson.at("error_code").as_int64(), rpc::RippledError::rpcSLOW_DOWN); + EXPECT_EQ(responseJson.at("request").as_string(), requestStr); + }); +} + +TEST_F(NgRpcServerHandlerTest, DosguardRejectedWsJsonRequest) +{ + runSpawn([&](boost::asio::yield_context yield) { + auto const requestStr = R"json({"request": "some message", "id": "some id"})json"; + auto const request = makeWsRequest(requestStr); + + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(false)); + auto response = rpcServerHandler_(request, connectionMetadata_, nullptr, yield); + + auto const responseWs = boost::beast::buffers_to_string(response.asWsResponse()); + + auto const responseJson = boost::json::parse(responseWs).as_object(); + EXPECT_EQ(responseJson.at("error_code").as_int64(), rpc::RippledError::rpcSLOW_DOWN); + EXPECT_EQ(responseJson.at("request").as_string(), requestStr); + EXPECT_EQ(responseJson.at("id").as_string(), "some id"); + }); +} + TEST_F(NgRpcServerHandlerTest, PostToRpcEngineFailed) { runSpawn([&](boost::asio::yield_context yield) { auto const request = makeHttpRequest("some message"); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce(Return(false)); EXPECT_CALL(*rpcEngine_, notifyTooBusy()); auto response = rpcServerHandler_(request, connectionMetadata_, nullptr, yield); @@ -107,6 +170,8 @@ TEST_F(NgRpcServerHandlerTest, CoroutineSleepsUntilRpcEngineFinishes) runSpawn([&](boost::asio::yield_context yield) { auto const request = makeHttpRequest("some message"); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { boost::asio::spawn( ctx_, @@ -131,6 +196,8 @@ TEST_F(NgRpcServerHandlerTest, JsonParseFailed) runSpawn([&](boost::asio::yield_context yield) { auto const request = makeHttpRequest("not a json"); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(*rpcEngine_, notifyBadSyntax); fn(yield); @@ -141,10 +208,61 @@ TEST_F(NgRpcServerHandlerTest, JsonParseFailed) }); } +TEST_F(NgRpcServerHandlerTest, DosguardRejectedParsedRequest) +{ + runSpawn([&](boost::asio::yield_context yield) { + std::string const requestStr = "{}"; + auto const request = makeHttpRequest(requestStr); + + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(false)); + EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { + fn(yield); + return true; + }); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); + + auto response = rpcServerHandler_(request, connectionMetadata_, nullptr, yield); + auto const responseHttp = std::move(response).intoHttpResponse(); + EXPECT_EQ(responseHttp.result(), http::status::service_unavailable); + + auto const responseJson = boost::json::parse(responseHttp.body()).as_object(); + EXPECT_EQ(responseJson.at("error_code").as_int64(), rpc::RippledError::rpcSLOW_DOWN); + }); +} + +TEST_F(NgRpcServerHandlerTest, DosguardAddsLoadWarning) +{ + runSpawn([&](boost::asio::yield_context yield) { + std::string const requestStr = "{}"; + auto const request = makeHttpRequest(requestStr); + + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(false)); + EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { + fn(yield); + return true; + }); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(false)); + + auto response = rpcServerHandler_(request, connectionMetadata_, nullptr, yield); + auto const responseHttp = std::move(response).intoHttpResponse(); + EXPECT_EQ(responseHttp.result(), http::status::service_unavailable); + + auto const responseJson = boost::json::parse(responseHttp.body()).as_object(); + EXPECT_EQ(responseJson.at("error_code").as_int64(), rpc::RippledError::rpcSLOW_DOWN); + + EXPECT_EQ(responseJson.at("warning").as_string(), "load"); + EXPECT_EQ(responseJson.at("warnings").as_array().at(0).as_object().at("id").as_int64(), rpc::WarnRpcRateLimit); + }); +} + TEST_F(NgRpcServerHandlerTest, GotNotJsonObject) { runSpawn([&](boost::asio::yield_context yield) { auto const request = makeHttpRequest("[]"); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(*rpcEngine_, notifyBadSyntax); fn(yield); @@ -158,8 +276,12 @@ TEST_F(NgRpcServerHandlerTest, GotNotJsonObject) TEST_F(NgRpcServerHandlerTest, HandleRequest_NoRangeFromBackend) { runSpawn([&](boost::asio::yield_context yield) { - auto const request = makeHttpRequest("{}"); + std::string const requestStr = "{}"; + auto const request = makeHttpRequest(requestStr); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(connectionMetadata_, wasUpgraded).WillOnce(Return(not request.isHttp())); EXPECT_CALL(*rpcEngine_, notifyNotReady); @@ -180,8 +302,12 @@ TEST_F(NgRpcServerHandlerTest, HandleRequest_ContextCreationFailed) { backend_->setRange(0, 1); runSpawn([&](boost::asio::yield_context yield) { - auto const request = makeHttpRequest("{}"); + std::string const requestStr = "{}"; + auto const request = makeHttpRequest(requestStr); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(connectionMetadata_, wasUpgraded).WillRepeatedly(Return(not request.isHttp())); EXPECT_CALL(*rpcEngine_, notifyBadSyntax); @@ -200,8 +326,12 @@ TEST_F(NgRpcServerHandlerTest, HandleRequest_BuildResponseFailed) { backend_->setRange(0, 1); runSpawn([&](boost::asio::yield_context yield) { - auto const request = makeHttpRequest(R"json({"method":"some_method"})json"); + std::string const requestStr = R"json({"method":"some_method"})json"; + auto const request = makeHttpRequest(requestStr); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(connectionMetadata_, wasUpgraded).WillRepeatedly(Return(not request.isHttp())); EXPECT_CALL(*rpcEngine_, buildResponse) @@ -227,8 +357,12 @@ TEST_F(NgRpcServerHandlerTest, HandleRequest_BuildResponseThrewAnException) { backend_->setRange(0, 1); runSpawn([&](boost::asio::yield_context yield) { - auto const request = makeHttpRequest(R"json({"method":"some_method"})json"); + std::string const requestStr = R"json({"method":"some_method"})json"; + auto const request = makeHttpRequest(requestStr); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(connectionMetadata_, wasUpgraded).WillRepeatedly(Return(not request.isHttp())); EXPECT_CALL(*rpcEngine_, buildResponse).WillOnce([](auto&&) -> rpc::Result { @@ -249,8 +383,12 @@ TEST_F(NgRpcServerHandlerTest, HandleRequest_Successful_HttpRequest) { backend_->setRange(0, 1); runSpawn([&](boost::asio::yield_context yield) { - auto const request = makeHttpRequest(R"json({"method":"some_method"})json"); + std::string const requestStr = R"json({"method":"some_method"})json"; + auto const request = makeHttpRequest(requestStr); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(connectionMetadata_, wasUpgraded).WillRepeatedly(Return(not request.isHttp())); EXPECT_CALL(*rpcEngine_, buildResponse) @@ -278,8 +416,12 @@ TEST_F(NgRpcServerHandlerTest, HandleRequest_OutdatedWarning) { backend_->setRange(0, 1); runSpawn([&](boost::asio::yield_context yield) { - auto const request = makeHttpRequest(R"json({"method":"some_method"})json"); + std::string const requestStr = R"json({"method":"some_method"})json"; + auto const request = makeHttpRequest(requestStr); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(connectionMetadata_, wasUpgraded).WillRepeatedly(Return(not request.isHttp())); EXPECT_CALL(*rpcEngine_, buildResponse) @@ -313,8 +455,12 @@ TEST_F(NgRpcServerHandlerTest, HandleRequest_Successful_HttpRequest_Forwarded) { backend_->setRange(0, 1); runSpawn([&](boost::asio::yield_context yield) { - auto const request = makeHttpRequest(R"json({"method":"some_method"})json"); + std::string const requestStr = R"json({"method":"some_method"})json"; + auto const request = makeHttpRequest(requestStr); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(connectionMetadata_, wasUpgraded).WillRepeatedly(Return(not request.isHttp())); EXPECT_CALL(*rpcEngine_, buildResponse) @@ -345,8 +491,12 @@ TEST_F(NgRpcServerHandlerTest, HandleRequest_Successful_HttpRequest_HasError) { backend_->setRange(0, 1); runSpawn([&](boost::asio::yield_context yield) { - auto const request = makeHttpRequest(R"json({"method":"some_method"})json"); + std::string const requestStr = R"json({"method":"some_method"})json"; + auto const request = makeHttpRequest(requestStr); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(connectionMetadata_, wasUpgraded).WillRepeatedly(Return(not request.isHttp())); EXPECT_CALL(*rpcEngine_, buildResponse) @@ -393,8 +543,12 @@ TEST_F(NgRpcServerHandlerWsTest, HandleRequest_Successful_WsRequest) backend_->setRange(0, 1); runSpawn([&](boost::asio::yield_context yield) { Request::HttpHeaders const headers; - auto const request = Request(R"json({"method":"some_method", "id": 1234, "api_version": 1})json", headers); + std::string const requestStr = R"json({"method":"some_method", "id": 1234, "api_version": 1})json"; + auto const request = Request(requestStr, headers); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(connectionMetadata_, wasUpgraded).WillRepeatedly(Return(not request.isHttp())); EXPECT_CALL(*rpcEngine_, buildResponse) @@ -424,8 +578,12 @@ TEST_F(NgRpcServerHandlerWsTest, HandleRequest_Successful_WsRequest_HasError) backend_->setRange(0, 1); runSpawn([&](boost::asio::yield_context yield) { Request::HttpHeaders const headers; - auto const request = Request(R"json({"method":"some_method", "id": 1234, "api_version": 1})json", headers); + std::string const requestStr = R"json({"method":"some_method", "id": 1234, "api_version": 1})json"; + auto const request = Request(requestStr, headers); + EXPECT_CALL(dosguard_, isOk(ip_)).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, request(ip_, boost::json::parse(requestStr).as_object())).WillOnce(Return(true)); + EXPECT_CALL(dosguard_, add(ip_, testing::_)).WillOnce(Return(true)); EXPECT_CALL(*rpcEngine_, post).WillOnce([&](auto&& fn, auto&&) { EXPECT_CALL(connectionMetadata_, wasUpgraded).WillRepeatedly(Return(not request.isHttp())); EXPECT_CALL(*rpcEngine_, buildResponse) diff --git a/tools/requests_gun/internal/request_maker/request_maker.go b/tools/requests_gun/internal/request_maker/request_maker.go index 92088c77..0c23bd99 100644 --- a/tools/requests_gun/internal/request_maker/request_maker.go +++ b/tools/requests_gun/internal/request_maker/request_maker.go @@ -74,6 +74,7 @@ func NewHttp(host string, port uint) *HttpRequestMaker { host = "http://" + host } transport := http.DefaultTransport.(*http.Transport).Clone() + transport.DisableKeepAlives = true client := &http.Client{Transport: transport} return &HttpRequestMaker{host + ":" + fmt.Sprintf("%d", port), transport, client}