diff --git a/src/webserver/DOSGuard.h b/src/webserver/DOSGuard.h index 4ae34f58..f7f579e2 100644 --- a/src/webserver/DOSGuard.h +++ b/src/webserver/DOSGuard.h @@ -12,16 +12,19 @@ class DOSGuard boost::asio::io_context& ctx_; std::mutex mtx_; // protects ipFetchCount_ std::unordered_map ipFetchCount_; + std::unordered_map ipConnCount_; std::unordered_set const whitelist_; std::uint32_t const maxFetches_; std::uint32_t const sweepInterval_; + std::uint32_t const maxConnCount_; public: DOSGuard(clio::Config const& config, boost::asio::io_context& ctx) : ctx_{ctx} , whitelist_{getWhitelist(config)} - , maxFetches_{config.valueOr("dos_guard.max_fetches", 100u)} - , sweepInterval_{config.valueOr("dos_guard.sweep_interval", 1u)} + , maxFetches_{config.valueOr("dos_guard.max_fetches", 100000000u)} + , sweepInterval_{config.valueOr("dos_guard.sweep_interval", 10u)} + , maxConnCount_{config.valueOr("dos_guard.max_connections", 1u)} { createTimer(); } @@ -53,11 +56,44 @@ public: return true; std::unique_lock lck(mtx_); - auto it = ipFetchCount_.find(ip); - if (it == ipFetchCount_.end()) - return true; + bool fetchesOk = true; + bool connsOk = true; + { + auto it = ipFetchCount_.find(ip); + if (it != ipFetchCount_.end()) + fetchesOk = it->second <= maxFetches_; + } + { + auto it = ipConnCount_.find(ip); + assert(it != ipConnCount_.end()); + if (it != ipConnCount_.end()) + { + connsOk = it->second <= maxConnCount_; + } + } - return it->second < maxFetches_; + return fetchesOk && connsOk; + } + + void + increment(std::string const& ip) + { + if (whitelist_.contains(ip)) + return; + std::unique_lock lck{mtx_}; + ipConnCount_[ip]++; + } + + void + decrement(std::string const& ip) + { + if (whitelist_.contains(ip)) + return; + std::unique_lock lck{mtx_}; + assert(ipConnCount_[ip] > 0); + ipConnCount_[ip]--; + if (ipConnCount_[ip] == 0) + ipConnCount_.erase(ip); } bool diff --git a/src/webserver/HttpBase.h b/src/webserver/HttpBase.h index 0b62859b..49f66362 100644 --- a/src/webserver/HttpBase.h +++ b/src/webserver/HttpBase.h @@ -104,6 +104,7 @@ class HttpBase : public util::Taggable protected: clio::Logger perfLog_{"Performance"}; boost::beast::flat_buffer buffer_; + bool upgraded_ = false; bool dead() @@ -171,11 +172,18 @@ public: { perfLog_.debug() << tag() << "http session created"; } + virtual ~HttpBase() { perfLog_.debug() << tag() << "http session closed"; } + DOSGuard& + dosGuard() + { + return dosGuard_; + } + void do_read() { @@ -212,12 +220,14 @@ public: if (boost::beast::websocket::is_upgrade(req_)) { + upgraded_ = true; // Disable the timeout. // The websocket::stream uses its own timeout settings. boost::beast::get_lowest_layer(derived().stream()).expires_never(); return make_websocket_session( ioc_, derived().release_stream(), + derived().ip(), std::move(req_), std::move(buffer_), backend_, diff --git a/src/webserver/HttpSession.h b/src/webserver/HttpSession.h index ce3fe987..2c7c9b1d 100644 --- a/src/webserver/HttpSession.h +++ b/src/webserver/HttpSession.h @@ -12,6 +12,7 @@ class HttpSession : public HttpBase, public std::enable_shared_from_this { boost::beast::tcp_stream stream_; + std::optional ip_; public: // Take ownership of the socket @@ -40,6 +41,21 @@ public: std::move(buffer)) , stream_(std::move(socket)) { + try + { + ip_ = stream_.socket().remote_endpoint().address().to_string(); + } + catch (std::exception const&) + { + } + if (ip_) + HttpBase::dosGuard().increment(*ip_); + } + + ~HttpSession() + { + if (ip_ and not upgraded_) + HttpBase::dosGuard().decrement(*ip_); } boost::beast::tcp_stream& @@ -56,14 +72,7 @@ public: std::optional ip() { - try - { - return stream_.socket().remote_endpoint().address().to_string(); - } - catch (std::exception const&) - { - return {}; - } + return ip_; } // Start the asynchronous operation diff --git a/src/webserver/Listener.h b/src/webserver/Listener.h index 1f479c3d..bdde4a69 100644 --- a/src/webserver/Listener.h +++ b/src/webserver/Listener.h @@ -137,6 +137,7 @@ void make_websocket_session( boost::asio::io_context& ioc, boost::beast::tcp_stream stream, + std::optional const& ip, http::request req, boost::beast::flat_buffer buffer, std::shared_ptr backend, @@ -151,6 +152,7 @@ make_websocket_session( std::make_shared( ioc, std::move(stream), + ip, backend, subscriptions, balancer, @@ -168,6 +170,7 @@ void make_websocket_session( boost::asio::io_context& ioc, boost::beast::ssl_stream stream, + std::optional const& ip, http::request req, boost::beast::flat_buffer buffer, std::shared_ptr backend, @@ -182,6 +185,7 @@ make_websocket_session( std::make_shared( ioc, std::move(stream), + ip, backend, subscriptions, balancer, diff --git a/src/webserver/PlainWsSession.h b/src/webserver/PlainWsSession.h index 62012ca6..6862452f 100644 --- a/src/webserver/PlainWsSession.h +++ b/src/webserver/PlainWsSession.h @@ -31,6 +31,7 @@ public: explicit PlainWsSession( boost::asio::io_context& ioc, boost::asio::ip::tcp::socket&& socket, + std::optional ip, std::shared_ptr backend, std::shared_ptr subscriptions, std::shared_ptr balancer, @@ -42,6 +43,7 @@ public: boost::beast::flat_buffer&& buffer) : WsSession( ioc, + ip, backend, subscriptions, balancer, @@ -64,19 +66,7 @@ public: std::optional ip() { - try - { - return ws() - .next_layer() - .socket() - .remote_endpoint() - .address() - .to_string(); - } - catch (std::exception const&) - { - return {}; - } + return ip_; } ~PlainWsSession() = default; @@ -97,11 +87,13 @@ class WsUpgrader : public std::enable_shared_from_this RPC::Counters& counters_; WorkQueue& queue_; http::request req_; + std::optional ip_; public: WsUpgrader( boost::asio::io_context& ioc, boost::asio::ip::tcp::socket&& socket, + std::optional ip, std::shared_ptr backend, std::shared_ptr subscriptions, std::shared_ptr balancer, @@ -122,11 +114,13 @@ public: , dosGuard_(dosGuard) , counters_(counters) , queue_(queue) + , ip_(ip) { } WsUpgrader( boost::asio::io_context& ioc, boost::beast::tcp_stream&& stream, + std::optional ip, std::shared_ptr backend, std::shared_ptr subscriptions, std::shared_ptr balancer, @@ -149,6 +143,7 @@ public: , counters_(counters) , queue_(queue) , req_(std::move(req)) + , ip_(ip) { } @@ -197,6 +192,7 @@ private: std::make_shared( ioc_, http_.release_socket(), + ip_, backend_, subscriptions_, balancer_, diff --git a/src/webserver/SslHttpSession.h b/src/webserver/SslHttpSession.h index e25141e9..fe46727d 100644 --- a/src/webserver/SslHttpSession.h +++ b/src/webserver/SslHttpSession.h @@ -12,6 +12,7 @@ class SslHttpSession : public HttpBase, public std::enable_shared_from_this { boost::beast::ssl_stream stream_; + std::optional ip_; public: // Take ownership of the socket @@ -41,6 +42,25 @@ public: std::move(buffer)) , stream_(std::move(socket), ctx) { + try + { + ip_ = stream_.next_layer() + .socket() + .remote_endpoint() + .address() + .to_string(); + } + catch (std::exception const&) + { + } + if (ip_) + HttpBase::dosGuard().increment(*ip_); + } + + ~SslHttpSession() + { + if (ip_ and not upgraded_) + HttpBase::dosGuard().decrement(*ip_); } boost::beast::ssl_stream& @@ -57,18 +77,7 @@ public: std::optional ip() { - try - { - return stream_.next_layer() - .socket() - .remote_endpoint() - .address() - .to_string(); - } - catch (std::exception const&) - { - return {}; - } + return ip_; } // Start the asynchronous operation diff --git a/src/webserver/SslWsSession.h b/src/webserver/SslWsSession.h index 39532d03..7cb8a8f6 100644 --- a/src/webserver/SslWsSession.h +++ b/src/webserver/SslWsSession.h @@ -29,6 +29,7 @@ public: explicit SslWsSession( boost::asio::io_context& ioc, boost::beast::ssl_stream&& stream, + std::optional ip, std::shared_ptr backend, std::shared_ptr subscriptions, std::shared_ptr balancer, @@ -40,6 +41,7 @@ public: boost::beast::flat_buffer&& b) : WsSession( ioc, + ip, backend, subscriptions, balancer, @@ -62,20 +64,7 @@ public: std::optional ip() { - try - { - return ws() - .next_layer() - .next_layer() - .socket() - .remote_endpoint() - .address() - .to_string(); - } - catch (std::exception const&) - { - return {}; - } + return ip_; } }; @@ -85,6 +74,7 @@ class SslWsUpgrader : public std::enable_shared_from_this boost::beast::ssl_stream https_; boost::optional> parser_; boost::beast::flat_buffer buffer_; + std::optional ip_; std::shared_ptr backend_; std::shared_ptr subscriptions_; std::shared_ptr balancer_; @@ -98,6 +88,7 @@ class SslWsUpgrader : public std::enable_shared_from_this public: SslWsUpgrader( boost::asio::io_context& ioc, + std::optional ip, boost::asio::ip::tcp::socket&& socket, ssl::context& ctx, std::shared_ptr backend, @@ -112,6 +103,7 @@ public: : ioc_(ioc) , https_(std::move(socket), ctx) , buffer_(std::move(b)) + , ip_(ip) , backend_(backend) , subscriptions_(subscriptions) , balancer_(balancer) @@ -125,6 +117,7 @@ public: SslWsUpgrader( boost::asio::io_context& ioc, boost::beast::ssl_stream stream, + std::optional ip, std::shared_ptr backend, std::shared_ptr subscriptions, std::shared_ptr balancer, @@ -138,6 +131,7 @@ public: : ioc_(ioc) , https_(std::move(stream)) , buffer_(std::move(b)) + , ip_(ip) , backend_(backend) , subscriptions_(subscriptions) , balancer_(balancer) @@ -210,6 +204,7 @@ private: std::make_shared( ioc_, std::move(https_), + ip_, backend_, subscriptions_, balancer_, diff --git a/src/webserver/WsBase.h b/src/webserver/WsBase.h index 14de053a..05dee77e 100644 --- a/src/webserver/WsBase.h +++ b/src/webserver/WsBase.h @@ -111,6 +111,9 @@ class WsSession : public WsBase, bool sending_ = false; std::queue> messages_; +protected: + std::optional ip_; + void wsFail(boost::beast::error_code ec, char const* what) { @@ -128,6 +131,7 @@ class WsSession : public WsBase, public: explicit WsSession( boost::asio::io_context& ioc, + std::optional ip, std::shared_ptr backend, std::shared_ptr subscriptions, std::shared_ptr balancer, @@ -148,12 +152,16 @@ public: , dosGuard_(dosGuard) , counters_(counters) , queue_(queue) + , ip_(ip) { perfLog_.info() << tag() << "session created"; } + virtual ~WsSession() { perfLog_.info() << tag() << "session closed"; + if (ip_) + dosGuard_.decrement(*ip_); } // Access the derived class, this is part of