diff --git a/src/webserver/DOSGuard.h b/src/webserver/DOSGuard.h index 9d5bd4a3..902881d4 100644 --- a/src/webserver/DOSGuard.h +++ b/src/webserver/DOSGuard.h @@ -48,13 +48,24 @@ public: template class BasicDOSGuard : public BaseDOSGuard { - mutable std::mutex mtx_; // protects ipFetchCount_ - std::unordered_map ipFetchCount_; + // Accumulated state per IP, state will be reset accordingly + struct ClientState + { + // accumulated transfered byte + std::uint32_t transferedByte = 0; + // accumulated served requests count + std::uint32_t requestsCount = 0; + }; + + mutable std::mutex mtx_; + // accumulated states map + std::unordered_map ipState_; std::unordered_map ipConnCount_; std::unordered_set const whitelist_; std::uint32_t const maxFetches_; std::uint32_t const maxConnCount_; + std::uint32_t const maxRequestCount_; clio::Logger log_{"RPC"}; public: @@ -68,6 +79,7 @@ public: : whitelist_{getWhitelist(config)} , maxFetches_{config.valueOr("dos_guard.max_fetches", 100000000u)} , maxConnCount_{config.valueOr("dos_guard.max_connections", 1u)} + , maxRequestCount_{config.valueOr("dos_guard.max_requests", 10u)} { sweepHandler.setup(this); } @@ -98,25 +110,33 @@ public: if (whitelist_.contains(ip)) return true; - std::unique_lock lck(mtx_); - bool fetchesOk = true; - bool connsOk = true; - { - auto it = ipFetchCount_.find(ip); - if (it != ipFetchCount_.end()) - fetchesOk = it->second <= maxFetches_; - } { + std::unique_lock lck(mtx_); + if (ipState_.find(ip) != ipState_.end()) + { + auto [transferedByte, requests] = ipState_.at(ip); + if (transferedByte > maxFetches_ || requests > maxRequestCount_) + { + log_.warn() + << "Dosguard:Client surpassed the rate limit. ip = " + << ip << " Transfered Byte:" << transferedByte + << " Requests:" << requests; + return false; + } + } auto it = ipConnCount_.find(ip); if (it != ipConnCount_.end()) { - connsOk = it->second <= maxConnCount_; + if (it->second > maxConnCount_) + { + log_.warn() + << "Dosguard:Client surpassed the rate limit. ip = " + << ip << " Concurrent connection:" << it->second; + return false; + } } } - if (!fetchesOk || !connsOk) - log_.warn() << "Client surpassed the rate limit. ip = " << ip; - - return fetchesOk && connsOk; + return true; } /** @@ -170,11 +190,32 @@ public: { std::unique_lock lck(mtx_); - auto it = ipFetchCount_.find(ip); - if (it == ipFetchCount_.end()) - ipFetchCount_[ip] = numObjects; - else - it->second += numObjects; + ipState_[ip].transferedByte += numObjects; + } + + return isOk(ip); + } + + /** + * @brief Adds one request for the given ip address. + * + * If the total sums up to a value equal or larger than maxRequestCount_ + * the operation is no longer allowed and false is returned; true is + * returned otherwise. + * + * @param ip + * @return true + * @return false + */ + [[maybe_unused]] bool + request(std::string const& ip) noexcept + { + if (whitelist_.contains(ip)) + return true; + + { + std::unique_lock lck(mtx_); + ipState_[ip].requestsCount++; } return isOk(ip); @@ -188,7 +229,7 @@ public: clear() noexcept override { std::unique_lock lck(mtx_); - ipFetchCount_.clear(); + ipState_.clear(); } private: diff --git a/src/webserver/HttpBase.h b/src/webserver/HttpBase.h index 0e07799f..bed3ec7e 100644 --- a/src/webserver/HttpBase.h +++ b/src/webserver/HttpBase.h @@ -238,6 +238,27 @@ public: if (ec) return httpFail(ec, "read"); + auto ip = derived().ip(); + + if (!ip) + { + return; + } + + auto const httpResponse = [&](http::status status, + std::string content_type, + std::string message) { + http::response res{status, req_.version()}; + res.set( + http::field::server, + "clio-server-" + Build::getClioVersionString()); + res.set(http::field::content_type, content_type); + res.keep_alive(req_.keep_alive()); + res.body() = std::string(message); + res.prepare_payload(); + return res; + }; + if (boost::beast::websocket::is_upgrade(req_)) { upgraded_ = true; @@ -260,10 +281,16 @@ public: workQueue_); } - auto ip = derived().ip(); - - if (!ip) - return; + // to avoid overwhelm work queue, the request limit check should be + // before posting to queue the web socket creation will be guarded via + // connection limit + if (!dosGuard_.request(ip.value())) + { + return lambda_(httpResponse( + http::status::service_unavailable, + "text/plain", + "Server is overloaded")); + } perfLog_.debug() << tag() << "Received request from ip = " << *ip << " - posting to WorkQueue"; @@ -293,17 +320,11 @@ public: { // Non-whitelist connection rejected due to full connection // queue - http::response res{ - http::status::ok, req_.version()}; - res.set( - http::field::server, - "clio-server-" + Build::getClioVersionString()); - res.set(http::field::content_type, "application/json"); - res.keep_alive(req_.keep_alive()); - res.body() = boost::json::serialize( - RPC::makeError(RPC::RippledError::rpcTOO_BUSY)); - res.prepare_payload(); - lambda_(std::move(res)); + lambda_(httpResponse( + http::status::ok, + "application/json", + boost::json::serialize( + RPC::makeError(RPC::RippledError::rpcTOO_BUSY)))); } } @@ -380,12 +401,6 @@ handle_request( return send(httpResponse( http::status::bad_request, "text/html", "Expected a POST request")); - if (!dosGuard.isOk(ip)) - return send(httpResponse( - http::status::service_unavailable, - "text/plain", - "Server is overloaded")); - try { perfLog.debug() << http->tag() diff --git a/src/webserver/WsBase.h b/src/webserver/WsBase.h index e983597d..f7422b0f 100644 --- a/src/webserver/WsBase.h +++ b/src/webserver/WsBase.h @@ -452,19 +452,22 @@ public: }(std::move(msg)); boost::json::object request; - if (!raw.is_object()) - return sendError( - RPC::RippledError::rpcINVALID_PARAMS, nullptr, request); - request = raw.as_object(); - - auto id = request.contains("id") ? request.at("id") : nullptr; - - if (!dosGuard_.isOk(*ip)) + // dosGuard served request++ and check ip address + // dosGuard should check before any request, even invalid request + if (!dosGuard_.request(*ip)) { - sendError(RPC::RippledError::rpcSLOW_DOWN, id, request); + sendError(RPC::RippledError::rpcSLOW_DOWN, nullptr, request); + } + else if (!raw.is_object()) + { + // handle invalid request and async read again + sendError(RPC::RippledError::rpcINVALID_PARAMS, nullptr, request); } else { + request = raw.as_object(); + + auto id = request.contains("id") ? request.at("id") : nullptr; perfLog_.debug() << tag() << "Adding to work queue"; if (!queue_.postCoro( diff --git a/unittests/DOSGuard.cpp b/unittests/DOSGuard.cpp index 7053135b..3f629eb0 100644 --- a/unittests/DOSGuard.cpp +++ b/unittests/DOSGuard.cpp @@ -37,6 +37,7 @@ constexpr static auto JSONData = R"JSON( "max_fetches": 100, "sweep_interval": 1, "max_connections": 2, + "max_requests": 3, "whitelist": ["127.0.0.1"] } } @@ -126,6 +127,30 @@ TEST_F(DOSGuardTest, ClearFetchCountOnTimer) EXPECT_TRUE(guard.isOk(IP)); // can fetch again } +TEST_F(DOSGuardTest, RequestLimit) +{ + EXPECT_TRUE(guard.request(IP)); + EXPECT_TRUE(guard.request(IP)); + EXPECT_TRUE(guard.request(IP)); + EXPECT_TRUE(guard.isOk(IP)); + EXPECT_FALSE(guard.request(IP)); + EXPECT_FALSE(guard.isOk(IP)); + guard.clear(); + EXPECT_TRUE(guard.isOk(IP)); // can request again +} + +TEST_F(DOSGuardTest, RequestLimitOnTimer) +{ + EXPECT_TRUE(guard.request(IP)); + EXPECT_TRUE(guard.request(IP)); + EXPECT_TRUE(guard.request(IP)); + EXPECT_TRUE(guard.isOk(IP)); + EXPECT_FALSE(guard.request(IP)); + EXPECT_FALSE(guard.isOk(IP)); + sweepHandler.sweep(); + EXPECT_TRUE(guard.isOk(IP)); // can request again +} + template struct BasicDOSGuardMock : public BaseDOSGuard {