add connection counting (#433)

This commit is contained in:
CJ Cobb
2022-12-07 13:12:38 -05:00
committed by GitHub
parent 3ec5755930
commit 8a1f00debb
8 changed files with 120 additions and 53 deletions

View File

@@ -12,16 +12,19 @@ class DOSGuard
boost::asio::io_context& ctx_;
std::mutex mtx_; // protects ipFetchCount_
std::unordered_map<std::string, std::uint32_t> ipFetchCount_;
std::unordered_map<std::string, std::uint32_t> ipConnCount_;
std::unordered_set<std::string> const whitelist_;
std::uint32_t const maxFetches_;
std::uint32_t const sweepInterval_;
std::uint32_t const maxConnCount_;
public:
DOSGuard(clio::Config const& config, boost::asio::io_context& ctx)
: ctx_{ctx}
, whitelist_{getWhitelist(config)}
, maxFetches_{config.valueOr("dos_guard.max_fetches", 100u)}
, sweepInterval_{config.valueOr("dos_guard.sweep_interval", 1u)}
, maxFetches_{config.valueOr("dos_guard.max_fetches", 100000000u)}
, sweepInterval_{config.valueOr("dos_guard.sweep_interval", 10u)}
, maxConnCount_{config.valueOr("dos_guard.max_connections", 1u)}
{
createTimer();
}
@@ -53,11 +56,44 @@ public:
return true;
std::unique_lock lck(mtx_);
auto it = ipFetchCount_.find(ip);
if (it == ipFetchCount_.end())
return true;
bool fetchesOk = true;
bool connsOk = true;
{
auto it = ipFetchCount_.find(ip);
if (it != ipFetchCount_.end())
fetchesOk = it->second <= maxFetches_;
}
{
auto it = ipConnCount_.find(ip);
assert(it != ipConnCount_.end());
if (it != ipConnCount_.end())
{
connsOk = it->second <= maxConnCount_;
}
}
return it->second < maxFetches_;
return fetchesOk && connsOk;
}
void
increment(std::string const& ip)
{
if (whitelist_.contains(ip))
return;
std::unique_lock lck{mtx_};
ipConnCount_[ip]++;
}
void
decrement(std::string const& ip)
{
if (whitelist_.contains(ip))
return;
std::unique_lock lck{mtx_};
assert(ipConnCount_[ip] > 0);
ipConnCount_[ip]--;
if (ipConnCount_[ip] == 0)
ipConnCount_.erase(ip);
}
bool

View File

@@ -104,6 +104,7 @@ class HttpBase : public util::Taggable
protected:
clio::Logger perfLog_{"Performance"};
boost::beast::flat_buffer buffer_;
bool upgraded_ = false;
bool
dead()
@@ -171,11 +172,18 @@ public:
{
perfLog_.debug() << tag() << "http session created";
}
virtual ~HttpBase()
{
perfLog_.debug() << tag() << "http session closed";
}
DOSGuard&
dosGuard()
{
return dosGuard_;
}
void
do_read()
{
@@ -212,12 +220,14 @@ public:
if (boost::beast::websocket::is_upgrade(req_))
{
upgraded_ = true;
// Disable the timeout.
// The websocket::stream uses its own timeout settings.
boost::beast::get_lowest_layer(derived().stream()).expires_never();
return make_websocket_session(
ioc_,
derived().release_stream(),
derived().ip(),
std::move(req_),
std::move(buffer_),
backend_,

View File

@@ -12,6 +12,7 @@ class HttpSession : public HttpBase<HttpSession>,
public std::enable_shared_from_this<HttpSession>
{
boost::beast::tcp_stream stream_;
std::optional<std::string> ip_;
public:
// Take ownership of the socket
@@ -40,6 +41,21 @@ public:
std::move(buffer))
, stream_(std::move(socket))
{
try
{
ip_ = stream_.socket().remote_endpoint().address().to_string();
}
catch (std::exception const&)
{
}
if (ip_)
HttpBase::dosGuard().increment(*ip_);
}
~HttpSession()
{
if (ip_ and not upgraded_)
HttpBase::dosGuard().decrement(*ip_);
}
boost::beast::tcp_stream&
@@ -56,14 +72,7 @@ public:
std::optional<std::string>
ip()
{
try
{
return stream_.socket().remote_endpoint().address().to_string();
}
catch (std::exception const&)
{
return {};
}
return ip_;
}
// Start the asynchronous operation

View File

@@ -137,6 +137,7 @@ void
make_websocket_session(
boost::asio::io_context& ioc,
boost::beast::tcp_stream stream,
std::optional<std::string> const& ip,
http::request<http::string_body> req,
boost::beast::flat_buffer buffer,
std::shared_ptr<BackendInterface const> backend,
@@ -151,6 +152,7 @@ make_websocket_session(
std::make_shared<WsUpgrader>(
ioc,
std::move(stream),
ip,
backend,
subscriptions,
balancer,
@@ -168,6 +170,7 @@ void
make_websocket_session(
boost::asio::io_context& ioc,
boost::beast::ssl_stream<boost::beast::tcp_stream> stream,
std::optional<std::string> const& ip,
http::request<http::string_body> req,
boost::beast::flat_buffer buffer,
std::shared_ptr<BackendInterface const> backend,
@@ -182,6 +185,7 @@ make_websocket_session(
std::make_shared<SslWsUpgrader>(
ioc,
std::move(stream),
ip,
backend,
subscriptions,
balancer,

View File

@@ -31,6 +31,7 @@ public:
explicit PlainWsSession(
boost::asio::io_context& ioc,
boost::asio::ip::tcp::socket&& socket,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer,
@@ -42,6 +43,7 @@ public:
boost::beast::flat_buffer&& buffer)
: WsSession(
ioc,
ip,
backend,
subscriptions,
balancer,
@@ -64,19 +66,7 @@ public:
std::optional<std::string>
ip()
{
try
{
return ws()
.next_layer()
.socket()
.remote_endpoint()
.address()
.to_string();
}
catch (std::exception const&)
{
return {};
}
return ip_;
}
~PlainWsSession() = default;
@@ -97,11 +87,13 @@ class WsUpgrader : public std::enable_shared_from_this<WsUpgrader>
RPC::Counters& counters_;
WorkQueue& queue_;
http::request<http::string_body> req_;
std::optional<std::string> ip_;
public:
WsUpgrader(
boost::asio::io_context& ioc,
boost::asio::ip::tcp::socket&& socket,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer,
@@ -122,11 +114,13 @@ public:
, dosGuard_(dosGuard)
, counters_(counters)
, queue_(queue)
, ip_(ip)
{
}
WsUpgrader(
boost::asio::io_context& ioc,
boost::beast::tcp_stream&& stream,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer,
@@ -149,6 +143,7 @@ public:
, counters_(counters)
, queue_(queue)
, req_(std::move(req))
, ip_(ip)
{
}
@@ -197,6 +192,7 @@ private:
std::make_shared<PlainWsSession>(
ioc_,
http_.release_socket(),
ip_,
backend_,
subscriptions_,
balancer_,

View File

@@ -12,6 +12,7 @@ class SslHttpSession : public HttpBase<SslHttpSession>,
public std::enable_shared_from_this<SslHttpSession>
{
boost::beast::ssl_stream<boost::beast::tcp_stream> stream_;
std::optional<std::string> ip_;
public:
// Take ownership of the socket
@@ -41,6 +42,25 @@ public:
std::move(buffer))
, stream_(std::move(socket), ctx)
{
try
{
ip_ = stream_.next_layer()
.socket()
.remote_endpoint()
.address()
.to_string();
}
catch (std::exception const&)
{
}
if (ip_)
HttpBase::dosGuard().increment(*ip_);
}
~SslHttpSession()
{
if (ip_ and not upgraded_)
HttpBase::dosGuard().decrement(*ip_);
}
boost::beast::ssl_stream<boost::beast::tcp_stream>&
@@ -57,18 +77,7 @@ public:
std::optional<std::string>
ip()
{
try
{
return stream_.next_layer()
.socket()
.remote_endpoint()
.address()
.to_string();
}
catch (std::exception const&)
{
return {};
}
return ip_;
}
// Start the asynchronous operation

View File

@@ -29,6 +29,7 @@ public:
explicit SslWsSession(
boost::asio::io_context& ioc,
boost::beast::ssl_stream<boost::beast::tcp_stream>&& stream,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer,
@@ -40,6 +41,7 @@ public:
boost::beast::flat_buffer&& b)
: WsSession(
ioc,
ip,
backend,
subscriptions,
balancer,
@@ -62,20 +64,7 @@ public:
std::optional<std::string>
ip()
{
try
{
return ws()
.next_layer()
.next_layer()
.socket()
.remote_endpoint()
.address()
.to_string();
}
catch (std::exception const&)
{
return {};
}
return ip_;
}
};
@@ -85,6 +74,7 @@ class SslWsUpgrader : public std::enable_shared_from_this<SslWsUpgrader>
boost::beast::ssl_stream<boost::beast::tcp_stream> https_;
boost::optional<http::request_parser<http::string_body>> parser_;
boost::beast::flat_buffer buffer_;
std::optional<std::string> ip_;
std::shared_ptr<BackendInterface const> backend_;
std::shared_ptr<SubscriptionManager> subscriptions_;
std::shared_ptr<ETLLoadBalancer> balancer_;
@@ -98,6 +88,7 @@ class SslWsUpgrader : public std::enable_shared_from_this<SslWsUpgrader>
public:
SslWsUpgrader(
boost::asio::io_context& ioc,
std::optional<std::string> ip,
boost::asio::ip::tcp::socket&& socket,
ssl::context& ctx,
std::shared_ptr<BackendInterface const> backend,
@@ -112,6 +103,7 @@ public:
: ioc_(ioc)
, https_(std::move(socket), ctx)
, buffer_(std::move(b))
, ip_(ip)
, backend_(backend)
, subscriptions_(subscriptions)
, balancer_(balancer)
@@ -125,6 +117,7 @@ public:
SslWsUpgrader(
boost::asio::io_context& ioc,
boost::beast::ssl_stream<boost::beast::tcp_stream> stream,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer,
@@ -138,6 +131,7 @@ public:
: ioc_(ioc)
, https_(std::move(stream))
, buffer_(std::move(b))
, ip_(ip)
, backend_(backend)
, subscriptions_(subscriptions)
, balancer_(balancer)
@@ -210,6 +204,7 @@ private:
std::make_shared<SslWsSession>(
ioc_,
std::move(https_),
ip_,
backend_,
subscriptions_,
balancer_,

View File

@@ -111,6 +111,9 @@ class WsSession : public WsBase,
bool sending_ = false;
std::queue<std::shared_ptr<Message>> messages_;
protected:
std::optional<std::string> ip_;
void
wsFail(boost::beast::error_code ec, char const* what)
{
@@ -128,6 +131,7 @@ class WsSession : public WsBase,
public:
explicit WsSession(
boost::asio::io_context& ioc,
std::optional<std::string> ip,
std::shared_ptr<BackendInterface const> backend,
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer,
@@ -148,12 +152,16 @@ public:
, dosGuard_(dosGuard)
, counters_(counters)
, queue_(queue)
, ip_(ip)
{
perfLog_.info() << tag() << "session created";
}
virtual ~WsSession()
{
perfLog_.info() << tag() << "session closed";
if (ip_)
dosGuard_.decrement(*ip_);
}
// Access the derived class, this is part of