add connection counting (#433)

This commit is contained in:
CJ Cobb
2022-12-07 13:12:38 -05:00
committed by GitHub
parent 3ec5755930
commit 8a1f00debb
8 changed files with 120 additions and 53 deletions

View File

@@ -12,16 +12,19 @@ class DOSGuard
boost::asio::io_context& ctx_; boost::asio::io_context& ctx_;
std::mutex mtx_; // protects ipFetchCount_ std::mutex mtx_; // protects ipFetchCount_
std::unordered_map<std::string, std::uint32_t> ipFetchCount_; std::unordered_map<std::string, std::uint32_t> ipFetchCount_;
std::unordered_map<std::string, std::uint32_t> ipConnCount_;
std::unordered_set<std::string> const whitelist_; std::unordered_set<std::string> const whitelist_;
std::uint32_t const maxFetches_; std::uint32_t const maxFetches_;
std::uint32_t const sweepInterval_; std::uint32_t const sweepInterval_;
std::uint32_t const maxConnCount_;
public: public:
DOSGuard(clio::Config const& config, boost::asio::io_context& ctx) DOSGuard(clio::Config const& config, boost::asio::io_context& ctx)
: ctx_{ctx} : ctx_{ctx}
, whitelist_{getWhitelist(config)} , whitelist_{getWhitelist(config)}
, maxFetches_{config.valueOr("dos_guard.max_fetches", 100u)} , maxFetches_{config.valueOr("dos_guard.max_fetches", 100000000u)}
, sweepInterval_{config.valueOr("dos_guard.sweep_interval", 1u)} , sweepInterval_{config.valueOr("dos_guard.sweep_interval", 10u)}
, maxConnCount_{config.valueOr("dos_guard.max_connections", 1u)}
{ {
createTimer(); createTimer();
} }
@@ -53,11 +56,44 @@ public:
return true; return true;
std::unique_lock lck(mtx_); std::unique_lock lck(mtx_);
auto it = ipFetchCount_.find(ip); bool fetchesOk = true;
if (it == ipFetchCount_.end()) bool connsOk = true;
return true; {
auto it = ipFetchCount_.find(ip);
if (it != ipFetchCount_.end())
fetchesOk = it->second <= maxFetches_;
}
{
auto it = ipConnCount_.find(ip);
assert(it != ipConnCount_.end());
if (it != ipConnCount_.end())
{
connsOk = it->second <= maxConnCount_;
}
}
return it->second < maxFetches_; return fetchesOk && connsOk;
}
void
increment(std::string const& ip)
{
if (whitelist_.contains(ip))
return;
std::unique_lock lck{mtx_};
ipConnCount_[ip]++;
}
void
decrement(std::string const& ip)
{
if (whitelist_.contains(ip))
return;
std::unique_lock lck{mtx_};
assert(ipConnCount_[ip] > 0);
ipConnCount_[ip]--;
if (ipConnCount_[ip] == 0)
ipConnCount_.erase(ip);
} }
bool bool

View File

@@ -104,6 +104,7 @@ class HttpBase : public util::Taggable
protected: protected:
clio::Logger perfLog_{"Performance"}; clio::Logger perfLog_{"Performance"};
boost::beast::flat_buffer buffer_; boost::beast::flat_buffer buffer_;
bool upgraded_ = false;
bool bool
dead() dead()
@@ -171,11 +172,18 @@ public:
{ {
perfLog_.debug() << tag() << "http session created"; perfLog_.debug() << tag() << "http session created";
} }
virtual ~HttpBase() virtual ~HttpBase()
{ {
perfLog_.debug() << tag() << "http session closed"; perfLog_.debug() << tag() << "http session closed";
} }
DOSGuard&
dosGuard()
{
return dosGuard_;
}
void void
do_read() do_read()
{ {
@@ -212,12 +220,14 @@ public:
if (boost::beast::websocket::is_upgrade(req_)) if (boost::beast::websocket::is_upgrade(req_))
{ {
upgraded_ = true;
// Disable the timeout. // Disable the timeout.
// The websocket::stream uses its own timeout settings. // The websocket::stream uses its own timeout settings.
boost::beast::get_lowest_layer(derived().stream()).expires_never(); boost::beast::get_lowest_layer(derived().stream()).expires_never();
return make_websocket_session( return make_websocket_session(
ioc_, ioc_,
derived().release_stream(), derived().release_stream(),
derived().ip(),
std::move(req_), std::move(req_),
std::move(buffer_), std::move(buffer_),
backend_, backend_,

View File

@@ -12,6 +12,7 @@ class HttpSession : public HttpBase<HttpSession>,
public std::enable_shared_from_this<HttpSession> public std::enable_shared_from_this<HttpSession>
{ {
boost::beast::tcp_stream stream_; boost::beast::tcp_stream stream_;
std::optional<std::string> ip_;
public: public:
// Take ownership of the socket // Take ownership of the socket
@@ -40,6 +41,21 @@ public:
std::move(buffer)) std::move(buffer))
, stream_(std::move(socket)) , stream_(std::move(socket))
{ {
try
{
ip_ = stream_.socket().remote_endpoint().address().to_string();
}
catch (std::exception const&)
{
}
if (ip_)
HttpBase::dosGuard().increment(*ip_);
}
~HttpSession()
{
if (ip_ and not upgraded_)
HttpBase::dosGuard().decrement(*ip_);
} }
boost::beast::tcp_stream& boost::beast::tcp_stream&
@@ -56,14 +72,7 @@ public:
std::optional<std::string> std::optional<std::string>
ip() ip()
{ {
try return ip_;
{
return stream_.socket().remote_endpoint().address().to_string();
}
catch (std::exception const&)
{
return {};
}
} }
// Start the asynchronous operation // Start the asynchronous operation

View File

@@ -137,6 +137,7 @@ void
make_websocket_session( make_websocket_session(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
boost::beast::tcp_stream stream, boost::beast::tcp_stream stream,
std::optional<std::string> const& ip,
http::request<http::string_body> req, http::request<http::string_body> req,
boost::beast::flat_buffer buffer, boost::beast::flat_buffer buffer,
std::shared_ptr<BackendInterface const> backend, std::shared_ptr<BackendInterface const> backend,
@@ -151,6 +152,7 @@ make_websocket_session(
std::make_shared<WsUpgrader>( std::make_shared<WsUpgrader>(
ioc, ioc,
std::move(stream), std::move(stream),
ip,
backend, backend,
subscriptions, subscriptions,
balancer, balancer,
@@ -168,6 +170,7 @@ void
make_websocket_session( make_websocket_session(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
boost::beast::ssl_stream<boost::beast::tcp_stream> stream, boost::beast::ssl_stream<boost::beast::tcp_stream> stream,
std::optional<std::string> const& ip,
http::request<http::string_body> req, http::request<http::string_body> req,
boost::beast::flat_buffer buffer, boost::beast::flat_buffer buffer,
std::shared_ptr<BackendInterface const> backend, std::shared_ptr<BackendInterface const> backend,
@@ -182,6 +185,7 @@ make_websocket_session(
std::make_shared<SslWsUpgrader>( std::make_shared<SslWsUpgrader>(
ioc, ioc,
std::move(stream), std::move(stream),
ip,
backend, backend,
subscriptions, subscriptions,
balancer, balancer,

View File

@@ -31,6 +31,7 @@ public:
explicit PlainWsSession( explicit PlainWsSession(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
boost::asio::ip::tcp::socket&& socket, boost::asio::ip::tcp::socket&& socket,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend, std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions, std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer, std::shared_ptr<ETLLoadBalancer> balancer,
@@ -42,6 +43,7 @@ public:
boost::beast::flat_buffer&& buffer) boost::beast::flat_buffer&& buffer)
: WsSession( : WsSession(
ioc, ioc,
ip,
backend, backend,
subscriptions, subscriptions,
balancer, balancer,
@@ -64,19 +66,7 @@ public:
std::optional<std::string> std::optional<std::string>
ip() ip()
{ {
try return ip_;
{
return ws()
.next_layer()
.socket()
.remote_endpoint()
.address()
.to_string();
}
catch (std::exception const&)
{
return {};
}
} }
~PlainWsSession() = default; ~PlainWsSession() = default;
@@ -97,11 +87,13 @@ class WsUpgrader : public std::enable_shared_from_this<WsUpgrader>
RPC::Counters& counters_; RPC::Counters& counters_;
WorkQueue& queue_; WorkQueue& queue_;
http::request<http::string_body> req_; http::request<http::string_body> req_;
std::optional<std::string> ip_;
public: public:
WsUpgrader( WsUpgrader(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
boost::asio::ip::tcp::socket&& socket, boost::asio::ip::tcp::socket&& socket,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend, std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions, std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer, std::shared_ptr<ETLLoadBalancer> balancer,
@@ -122,11 +114,13 @@ public:
, dosGuard_(dosGuard) , dosGuard_(dosGuard)
, counters_(counters) , counters_(counters)
, queue_(queue) , queue_(queue)
, ip_(ip)
{ {
} }
WsUpgrader( WsUpgrader(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
boost::beast::tcp_stream&& stream, boost::beast::tcp_stream&& stream,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend, std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions, std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer, std::shared_ptr<ETLLoadBalancer> balancer,
@@ -149,6 +143,7 @@ public:
, counters_(counters) , counters_(counters)
, queue_(queue) , queue_(queue)
, req_(std::move(req)) , req_(std::move(req))
, ip_(ip)
{ {
} }
@@ -197,6 +192,7 @@ private:
std::make_shared<PlainWsSession>( std::make_shared<PlainWsSession>(
ioc_, ioc_,
http_.release_socket(), http_.release_socket(),
ip_,
backend_, backend_,
subscriptions_, subscriptions_,
balancer_, balancer_,

View File

@@ -12,6 +12,7 @@ class SslHttpSession : public HttpBase<SslHttpSession>,
public std::enable_shared_from_this<SslHttpSession> public std::enable_shared_from_this<SslHttpSession>
{ {
boost::beast::ssl_stream<boost::beast::tcp_stream> stream_; boost::beast::ssl_stream<boost::beast::tcp_stream> stream_;
std::optional<std::string> ip_;
public: public:
// Take ownership of the socket // Take ownership of the socket
@@ -41,6 +42,25 @@ public:
std::move(buffer)) std::move(buffer))
, stream_(std::move(socket), ctx) , stream_(std::move(socket), ctx)
{ {
try
{
ip_ = stream_.next_layer()
.socket()
.remote_endpoint()
.address()
.to_string();
}
catch (std::exception const&)
{
}
if (ip_)
HttpBase::dosGuard().increment(*ip_);
}
~SslHttpSession()
{
if (ip_ and not upgraded_)
HttpBase::dosGuard().decrement(*ip_);
} }
boost::beast::ssl_stream<boost::beast::tcp_stream>& boost::beast::ssl_stream<boost::beast::tcp_stream>&
@@ -57,18 +77,7 @@ public:
std::optional<std::string> std::optional<std::string>
ip() ip()
{ {
try return ip_;
{
return stream_.next_layer()
.socket()
.remote_endpoint()
.address()
.to_string();
}
catch (std::exception const&)
{
return {};
}
} }
// Start the asynchronous operation // Start the asynchronous operation

View File

@@ -29,6 +29,7 @@ public:
explicit SslWsSession( explicit SslWsSession(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
boost::beast::ssl_stream<boost::beast::tcp_stream>&& stream, boost::beast::ssl_stream<boost::beast::tcp_stream>&& stream,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend, std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions, std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer, std::shared_ptr<ETLLoadBalancer> balancer,
@@ -40,6 +41,7 @@ public:
boost::beast::flat_buffer&& b) boost::beast::flat_buffer&& b)
: WsSession( : WsSession(
ioc, ioc,
ip,
backend, backend,
subscriptions, subscriptions,
balancer, balancer,
@@ -62,20 +64,7 @@ public:
std::optional<std::string> std::optional<std::string>
ip() ip()
{ {
try return ip_;
{
return ws()
.next_layer()
.next_layer()
.socket()
.remote_endpoint()
.address()
.to_string();
}
catch (std::exception const&)
{
return {};
}
} }
}; };
@@ -85,6 +74,7 @@ class SslWsUpgrader : public std::enable_shared_from_this<SslWsUpgrader>
boost::beast::ssl_stream<boost::beast::tcp_stream> https_; boost::beast::ssl_stream<boost::beast::tcp_stream> https_;
boost::optional<http::request_parser<http::string_body>> parser_; boost::optional<http::request_parser<http::string_body>> parser_;
boost::beast::flat_buffer buffer_; boost::beast::flat_buffer buffer_;
std::optional<std::string> ip_;
std::shared_ptr<BackendInterface const> backend_; std::shared_ptr<BackendInterface const> backend_;
std::shared_ptr<SubscriptionManager> subscriptions_; std::shared_ptr<SubscriptionManager> subscriptions_;
std::shared_ptr<ETLLoadBalancer> balancer_; std::shared_ptr<ETLLoadBalancer> balancer_;
@@ -98,6 +88,7 @@ class SslWsUpgrader : public std::enable_shared_from_this<SslWsUpgrader>
public: public:
SslWsUpgrader( SslWsUpgrader(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
std::optional<std::string> ip,
boost::asio::ip::tcp::socket&& socket, boost::asio::ip::tcp::socket&& socket,
ssl::context& ctx, ssl::context& ctx,
std::shared_ptr<BackendInterface const> backend, std::shared_ptr<BackendInterface const> backend,
@@ -112,6 +103,7 @@ public:
: ioc_(ioc) : ioc_(ioc)
, https_(std::move(socket), ctx) , https_(std::move(socket), ctx)
, buffer_(std::move(b)) , buffer_(std::move(b))
, ip_(ip)
, backend_(backend) , backend_(backend)
, subscriptions_(subscriptions) , subscriptions_(subscriptions)
, balancer_(balancer) , balancer_(balancer)
@@ -125,6 +117,7 @@ public:
SslWsUpgrader( SslWsUpgrader(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
boost::beast::ssl_stream<boost::beast::tcp_stream> stream, boost::beast::ssl_stream<boost::beast::tcp_stream> stream,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend, std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions, std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer, std::shared_ptr<ETLLoadBalancer> balancer,
@@ -138,6 +131,7 @@ public:
: ioc_(ioc) : ioc_(ioc)
, https_(std::move(stream)) , https_(std::move(stream))
, buffer_(std::move(b)) , buffer_(std::move(b))
, ip_(ip)
, backend_(backend) , backend_(backend)
, subscriptions_(subscriptions) , subscriptions_(subscriptions)
, balancer_(balancer) , balancer_(balancer)
@@ -210,6 +204,7 @@ private:
std::make_shared<SslWsSession>( std::make_shared<SslWsSession>(
ioc_, ioc_,
std::move(https_), std::move(https_),
ip_,
backend_, backend_,
subscriptions_, subscriptions_,
balancer_, balancer_,

View File

@@ -111,6 +111,9 @@ class WsSession : public WsBase,
bool sending_ = false; bool sending_ = false;
std::queue<std::shared_ptr<Message>> messages_; std::queue<std::shared_ptr<Message>> messages_;
protected:
std::optional<std::string> ip_;
void void
wsFail(boost::beast::error_code ec, char const* what) wsFail(boost::beast::error_code ec, char const* what)
{ {
@@ -128,6 +131,7 @@ class WsSession : public WsBase,
public: public:
explicit WsSession( explicit WsSession(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend, std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions, std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer, std::shared_ptr<ETLLoadBalancer> balancer,
@@ -148,12 +152,16 @@ public:
, dosGuard_(dosGuard) , dosGuard_(dosGuard)
, counters_(counters) , counters_(counters)
, queue_(queue) , queue_(queue)
, ip_(ip)
{ {
perfLog_.info() << tag() << "session created"; perfLog_.info() << tag() << "session created";
} }
virtual ~WsSession() virtual ~WsSession()
{ {
perfLog_.info() << tag() << "session closed"; perfLog_.info() << tag() << "session closed";
if (ip_)
dosGuard_.decrement(*ip_);
} }
// Access the derived class, this is part of // Access the derived class, this is part of