Support whitelisting for IPV4/IPV6 with CIDR (#796)

Fixes #244
This commit is contained in:
Peter Chen
2023-08-08 11:04:16 -04:00
committed by GitHub
parent 5411fd7497
commit fc1b5ae4da
17 changed files with 364 additions and 68 deletions

View File

@@ -179,7 +179,8 @@ try
// Rate limiter, to prevent abuse
auto sweepHandler = IntervalSweepHandler{config, ioc};
auto dosGuard = DOSGuard{config, sweepHandler};
auto whitelistHandler = WhitelistHandler{config};
auto dosGuard = DOSGuard{config, whitelistHandler, sweepHandler};
// Interface to the database
auto backend = Backend::make_Backend(ioc, config);

View File

@@ -19,15 +19,16 @@
#pragma once
#include <config/Config.h>
#include <webserver/impl/WhitelistHandler.h>
#include <boost/asio.hpp>
#include <boost/iterator/transform_iterator.hpp>
#include <boost/system/error_code.hpp>
#include <algorithm>
#include <chrono>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <config/Config.h>
#include <ctime>
namespace clio {
@@ -43,9 +44,10 @@ public:
/**
* @brief A simple denial of service guard used for rate limiting.
*
* @tparam SweepHandler Type of the sweep handler
* @tparam Type of the Whitelist Handler
* @tparam Type of the Sweep Handler
*/
template <typename SweepHandler>
template <typename WhitelistHandlerType, typename SweepHandlerType>
class BasicDOSGuard : public BaseDOSGuard
{
// Accumulated state per IP, state will be reset accordingly
@@ -61,7 +63,7 @@ class BasicDOSGuard : public BaseDOSGuard
// accumulated states map
std::unordered_map<std::string, ClientState> ipState_;
std::unordered_map<std::string, std::uint32_t> ipConnCount_;
std::unordered_set<std::string> const whitelist_;
std::reference_wrapper<WhitelistHandlerType const> whitelistHandler_;
std::uint32_t const maxFetches_;
std::uint32_t const maxConnCount_;
@@ -73,10 +75,14 @@ public:
* @brief Constructs a new DOS guard.
*
* @param config Clio config
* @param sweepHandler Sweep handler that implements the sweeping behaviour
* @param WhitelistHandlerType Whitelist handler that checks whitelist for ip addresses
* @param SweepHandlerType Sweep handler that implements the sweeping behaviour
*/
BasicDOSGuard(clio::Config const& config, SweepHandler& sweepHandler)
: whitelist_{getWhitelist(config)}
BasicDOSGuard(
clio::Config const& config,
WhitelistHandlerType const& whitelistHandler,
SweepHandlerType& sweepHandler)
: whitelistHandler_{std::cref(whitelistHandler)}
, maxFetches_{config.valueOr("dos_guard.max_fetches", 1000000u)}
, maxConnCount_{config.valueOr("dos_guard.max_connections", 20u)}
, maxRequestCount_{config.valueOr("dos_guard.max_requests", 20u)}
@@ -92,9 +98,9 @@ public:
* @return false
*/
[[nodiscard]] bool
isWhiteListed(std::string const& ip) const noexcept
isWhiteListed(std::string_view const ip) const noexcept
{
return whitelist_.contains(ip);
return whitelistHandler_.get().isWhiteListed(ip);
}
/**
@@ -107,7 +113,7 @@ public:
[[nodiscard]] bool
isOk(std::string const& ip) const noexcept
{
if (whitelist_.contains(ip))
if (whitelistHandler_.get().isWhiteListed(ip))
return true;
{
@@ -144,7 +150,7 @@ public:
void
increment(std::string const& ip) noexcept
{
if (whitelist_.contains(ip))
if (whitelistHandler_.get().isWhiteListed(ip))
return;
std::scoped_lock lck{mtx_};
ipConnCount_[ip]++;
@@ -158,7 +164,7 @@ public:
void
decrement(std::string const& ip) noexcept
{
if (whitelist_.contains(ip))
if (whitelistHandler_.get().isWhiteListed(ip))
return;
std::scoped_lock lck{mtx_};
assert(ipConnCount_[ip] > 0);
@@ -182,7 +188,7 @@ public:
[[maybe_unused]] bool
add(std::string const& ip, uint32_t numObjects) noexcept
{
if (whitelist_.contains(ip))
if (whitelistHandler_.get().isWhiteListed(ip))
return true;
{
@@ -207,7 +213,7 @@ public:
[[maybe_unused]] bool
request(std::string const& ip) noexcept
{
if (whitelist_.contains(ip))
if (whitelistHandler_.get().isWhiteListed(ip))
return true;
{
@@ -303,6 +309,6 @@ private:
}
};
using DOSGuard = BasicDOSGuard<IntervalSweepHandler>;
using DOSGuard = BasicDOSGuard<WhitelistHandler, IntervalSweepHandler>;
} // namespace clio

View File

@@ -20,7 +20,7 @@
#pragma once
#include <webserver/PlainWsSession.h>
#include <webserver/details/HttpBase.h>
#include <webserver/impl/HttpBase.h>
namespace Server {

View File

@@ -19,7 +19,7 @@
#pragma once
#include <webserver/details/WsBase.h>
#include <webserver/impl/WsBase.h>
namespace Server {

View File

@@ -25,7 +25,7 @@
#include <rpc/common/impl/APIVersionParser.h>
#include <util/JsonUtils.h>
#include <util/Profiler.h>
#include <webserver/details/ErrorHandling.h>
#include <webserver/impl/ErrorHandling.h>
#include <boost/json/parse.hpp>

View File

@@ -20,7 +20,7 @@
#pragma once
#include <webserver/SslWsSession.h>
#include <webserver/details/HttpBase.h>
#include <webserver/impl/HttpBase.h>
namespace Server {

View File

@@ -19,7 +19,7 @@
#pragma once
#include <webserver/details/WsBase.h>
#include <webserver/impl/WsBase.h>
namespace Server {

View File

@@ -0,0 +1,167 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2023, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include <boost/asio.hpp>
#include <fmt/core.h>
#include <regex>
#include <string>
#include <unordered_map>
#include <unordered_set>
namespace clio {
/**
* @brief A whitelist to remove rate limits of certain IP addresses
*
*/
class Whitelist
{
std::vector<boost::asio::ip::network_v4> subnetsV4_;
std::vector<boost::asio::ip::network_v6> subnetsV6_;
std::vector<boost::asio::ip::address> ips_;
public:
/**
* @brief Add network address to whitelist
*
* @param net Network part of the ip address
* @throws std::runtime::error when the network address is not valid
*/
void
add(std::string_view net)
{
using namespace boost::asio;
if (not isMask(net))
{
ips_.push_back(ip::make_address(net));
return;
}
if (isV4(net))
subnetsV4_.push_back(ip::make_network_v4(net));
else if (isV6(net))
subnetsV6_.push_back(ip::make_network_v6(net));
else
throw std::runtime_error(fmt::format("malformed network: {}", net.data()));
}
/**
* @brief Checks to see if ip address is whitelisted
*
* @param ip IP address
* @throws std::runtime::error when the network address is not valid
*/
bool
isWhiteListed(std::string_view ip) const
{
using namespace boost::asio;
auto const addr = ip::make_address(ip);
if (std::find(std::begin(ips_), std::end(ips_), addr) != std::end(ips_))
return true;
if (addr.is_v4())
return std::find_if(
std::begin(subnetsV4_), std::end(subnetsV4_), std::bind_front(&isInV4Subnet, std::cref(addr))) !=
std::end(subnetsV4_);
if (addr.is_v6())
return std::find_if(
std::begin(subnetsV6_), std::end(subnetsV6_), std::bind_front(&isInV6Subnet, std::cref(addr))) !=
std::end(subnetsV6_);
return false;
}
private:
static bool
isInV4Subnet(boost::asio::ip::address const& addr, boost::asio::ip::network_v4 const& subnet)
{
auto const range = subnet.hosts();
return range.find(addr.to_v4()) != range.end();
}
static bool
isInV6Subnet(boost::asio::ip::address const& addr, boost::asio::ip::network_v6 const& subnet)
{
auto const range = subnet.hosts();
return range.find(addr.to_v6()) != range.end();
}
bool
isV4(std::string_view net) const
{
static const std::regex ipv4CidrRegex(R"(^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}/\d{1,2}$)");
return std::regex_match(std::string(net), ipv4CidrRegex);
}
bool
isV6(std::string_view net) const
{
static const std::regex ipv6CidrRegex(R"(^([0-9A-Fa-f]{1,4}:){7}[0-9A-Fa-f]{1,4}/\d{1,3}$)");
return std::regex_match(std::string(net), ipv6CidrRegex);
}
bool
isMask(std::string_view net) const
{
return net.find('/') != std::string_view::npos;
}
};
/**
* @brief A simple handler to add/check elements in a whitelist
*
* @param arr map of net addresses to add to whitelist
*/
class WhitelistHandler
{
Whitelist whitelist_;
public:
WhitelistHandler(clio::Config const& config)
{
std::unordered_set<std::string> arr = getWhitelist(config);
for (auto const& net : arr)
whitelist_.add(net);
}
bool
isWhiteListed(std::string_view ip) const
{
return whitelist_.isWhiteListed(ip);
}
private:
[[nodiscard]] std::unordered_set<std::string> const
getWhitelist(clio::Config const& config) const
{
using T = std::unordered_set<std::string> const;
auto whitelist = config.arrayOr("dos_guard.whitelist", {});
auto const transform = [](auto const& elem) { return elem.template value<std::string>(); };
return T{
boost::transform_iterator(std::begin(whitelist), transform),
boost::transform_iterator(std::end(whitelist), transform)};
}
};
} // namespace clio