Advance DOSGuard (#96)

* Add C++20 features
* Make whitelist const
This commit is contained in:
undertome
2022-02-18 16:43:02 -06:00
committed by GitHub
parent 6567a5cc7d
commit c03b72ad51
17 changed files with 117 additions and 104 deletions

View File

@@ -1,5 +1,5 @@
set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -Wall -Werror -Wno-dangling-else")

View File

@@ -159,7 +159,6 @@ BackendInterface::fetchBookOffers(
BookOffersPage page;
const ripple::uint256 bookEnd = ripple::getQualityNext(book);
ripple::uint256 uTipIndex = book;
bool done = false;
std::vector<ripple::uint256> keys;
auto getMillis = [](auto diff) {
return std::chrono::duration_cast<std::chrono::milliseconds>(diff)

View File

@@ -106,9 +106,9 @@ struct WriteCallbackData
template <class T, class B>
struct BulkWriteCallbackData : public WriteCallbackData<T, B>
{
std::atomic_int& numRemaining;
std::mutex& mtx;
std::condition_variable& cv;
std::atomic_int& numRemaining;
BulkWriteCallbackData(
CassandraBackend const* b,
T&& d,
@@ -1006,7 +1006,6 @@ CassandraBackend::open(bool readOnly)
}
return true;
};
CassStatement* statement;
CassFuture* fut;
bool setupSessionAndTable = false;
while (!setupSessionAndTable)

View File

@@ -252,7 +252,7 @@ Pg::getSocket(boost::asio::yield_context& yield)
delete socket;
}};
return std::move(s);
return s;
}
PgResult
@@ -775,7 +775,6 @@ PgPool::PgPool(boost::asio::io_context& ioc, boost::json::object const& config)
if (config.contains("max_connections"))
config_.max_connections = config.at("max_connections").as_int64();
std::size_t timeout;
if (config.contains("timeout"))
config_.timeout =
std::chrono::seconds(config.at("timeout").as_uint64());
@@ -1111,7 +1110,7 @@ create table if not exists account_transactions10 partition of account_transacti
CREATE TABLE IF NOT EXISTS successor (
key bytea NOT NULL,
key bytea NOT NULL,
ledger_seq bigint NOT NULL,
next bytea NOT NULL,
PRIMARY KEY(key, ledger_seq)
@@ -1711,8 +1710,6 @@ getLedger(
"total_coins, closing_time, prev_closing_time, close_time_res, "
"close_flags, ledger_seq FROM ledgers ";
std::uint32_t expNumResults = 1;
if (auto ledgerSeq = std::get_if<std::uint32_t>(&whichLedger))
{
sql << "WHERE ledger_seq = " + std::to_string(*ledgerSeq);

View File

@@ -781,7 +781,6 @@ PostgresBackend::doOnlineDelete(
std::uint32_t minLedger = rng->maxSequence - numLedgersToKeep;
if (minLedger <= rng->minSequence)
return false;
std::uint32_t limit = 2048;
PgQuery pgQuery(pgPool_);
pgQuery("SET statement_timeout TO 0", yield);
std::optional<ripple::uint256> cursor;

View File

@@ -22,13 +22,13 @@ ETLSourceImpl<Derived>::ETLSourceImpl(
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<NetworkValidatedLedgers> networkValidatedLedgers,
ETLLoadBalancer& balancer)
: ioc_(ioContext)
, resolver_(boost::asio::make_strand(ioContext))
, timer_(ioContext)
: resolver_(boost::asio::make_strand(ioContext))
, networkValidatedLedgers_(networkValidatedLedgers)
, backend_(backend)
, subscriptions_(subscriptions)
, balancer_(balancer)
, ioc_(ioContext)
, timer_(ioContext)
{
if (config.contains("ip"))
{
@@ -852,6 +852,44 @@ ETLSourceImpl<Derived>::fetchLedger(
return {status, std::move(response)};
}
static std::unique_ptr<ETLSource>
make_ETLSource(
boost::json::object const& config,
boost::asio::io_context& ioContext,
std::optional<std::reference_wrapper<boost::asio::ssl::context>> sslCtx,
std::shared_ptr<BackendInterface> backend,
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<NetworkValidatedLedgers> networkValidatedLedgers,
ETLLoadBalancer& balancer)
{
std::unique_ptr<ETLSource> src = nullptr;
if (sslCtx)
{
src = std::make_unique<SslETLSource>(
config,
ioContext,
sslCtx,
backend,
subscriptions,
networkValidatedLedgers,
balancer);
}
else
{
src = std::make_unique<PlainETLSource>(
config,
ioContext,
backend,
subscriptions,
networkValidatedLedgers,
balancer);
}
src->run();
return src;
}
ETLLoadBalancer::ETLLoadBalancer(
boost::json::object const& config,
boost::asio::io_context& ioContext,
@@ -873,7 +911,7 @@ ETLLoadBalancer::ETLLoadBalancer(
for (auto& entry : config.at("etl_sources").as_array())
{
std::unique_ptr<ETLSource> source = ETL::make_ETLSource(
std::unique_ptr<ETLSource> source = make_ETLSource(
entry.as_object(),
ioContext,
sslCtx,
@@ -914,7 +952,7 @@ ETLLoadBalancer::fetchLedger(
{
org::xrpl::rpc::v1::GetLedgerResponse response;
bool success = execute(
[&response, ledgerSequence, getObjects, getObjectNeighbors, this](
[&response, ledgerSequence, getObjects, getObjectNeighbors](
auto& source) {
auto [status, data] = source->fetchLedger(
ledgerSequence, getObjects, getObjectNeighbors);
@@ -1015,7 +1053,7 @@ ETLSourceImpl<Derived>::forwardToRippled(
//
// https://github.com/ripple/rippled/blob/develop/cfg/rippled-example.cfg
ws->set_option(websocket::stream_base::decorator(
[&request, &clientIp](websocket::request_type& req) {
[&clientIp](websocket::request_type& req) {
req.set(
http::field::user_agent,
std::string(BOOST_BEAST_VERSION_STRING) +

View File

@@ -420,46 +420,6 @@ public:
}
};
namespace ETL {
static std::unique_ptr<ETLSource>
make_ETLSource(
boost::json::object const& config,
boost::asio::io_context& ioContext,
std::optional<std::reference_wrapper<boost::asio::ssl::context>> sslCtx,
std::shared_ptr<BackendInterface> backend,
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<NetworkValidatedLedgers> networkValidatedLedgers,
ETLLoadBalancer& balancer)
{
std::unique_ptr<ETLSource> src = nullptr;
if (sslCtx)
{
src = std::make_unique<SslETLSource>(
config,
ioContext,
sslCtx,
backend,
subscriptions,
networkValidatedLedgers,
balancer);
}
else
{
src = std::make_unique<PlainETLSource>(
config,
ioContext,
backend,
subscriptions,
networkValidatedLedgers,
balancer);
}
src->run();
return src;
}
} // namespace ETL
/// This class is used to manage connections to transaction processing processes
/// This class spawns a listener for each etl source, which listens to messages
/// on the ledgers stream (to keep track of which ledgers have been validated by

View File

@@ -434,7 +434,6 @@ ReportingETL::buildNextLedger(org::xrpl::rpc::v1::GetLedgerResponse& rawData)
BOOST_LOG_TRIVIAL(debug)
<< __func__ << " object neighbors not included. using cache";
assert(backend_->cache().isFull());
size_t idx = 0;
for (auto const& obj : cacheUpdates)
{
if (modified.count(obj.key))
@@ -952,11 +951,11 @@ ReportingETL::ReportingETL(
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer,
std::shared_ptr<NetworkValidatedLedgers> ledgers)
: publishStrand_(ioc)
, ioContext_(ioc)
, backend_(backend)
: backend_(backend)
, subscriptions_(subscriptions)
, loadBalancer_(balancer)
, ioContext_(ioc)
, publishStrand_(ioc)
, networkValidatedLedgers_(ledgers)
{
if (config.contains("start_sequence"))

View File

@@ -33,11 +33,11 @@ struct Context
std::uint32_t version;
boost::json::object const& params;
std::shared_ptr<BackendInterface const> const& backend;
std::shared_ptr<ETLLoadBalancer> const& balancer;
// this needs to be an actual shared_ptr, not a reference. The above
// references refer to shared_ptr members of WsBase, but WsBase contains
// SubscriptionManager as a weak_ptr, to prevent a shared_ptr cycle.
std::shared_ptr<SubscriptionManager> subscriptions;
std::shared_ptr<ETLLoadBalancer> const& balancer;
std::shared_ptr<WsBase> session;
Backend::LedgerRange const& range;
Counters& counters;

View File

@@ -67,10 +67,6 @@ doAccountInfo(Context const& context)
context.backend->fetchLedgerObject(key.key, lgrInfo.seq, context.yield);
auto end = std::chrono::system_clock::now();
auto time =
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
.count();
if (!dbResponse)
{
return Status{Error::rpcACT_NOT_FOUND};

View File

@@ -72,10 +72,6 @@ doLedgerData(Context const& context)
auto end = std::chrono::system_clock::now();
auto time =
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
.count();
boost::json::object header;
if (!cursor)
{

View File

@@ -347,9 +347,6 @@ doLedgerEntry(Context const& context)
auto dbResponse =
context.backend->fetchLedgerObject(key, lgrInfo.seq, context.yield);
auto end = std::chrono::system_clock::now();
auto time =
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
.count();
if (!dbResponse or dbResponse->size() == 0)
return Status{Error::rpcLGR_NOT_FOUND};
@@ -372,4 +369,4 @@ doLedgerEntry(Context const& context)
return response;
}
} // namespace RPC
} // namespace RPC

View File

@@ -8,33 +8,63 @@
class DOSGuard
{
std::unordered_map<std::string, uint32_t> ipFetchCount_;
uint32_t maxFetches_ = 100;
uint32_t sweepInterval_ = 1;
std::unordered_set<std::string> whitelist_;
boost::asio::io_context& ctx_;
std::mutex mtx_;
std::mutex mtx_; // protects ipFetchCount_
std::unordered_map<std::string, std::uint32_t> ipFetchCount_;
std::unordered_set<std::string> const whitelist_;
std::uint32_t const maxFetches_;
std::uint32_t const sweepInterval_;
std::optional<boost::json::object>
getConfig(boost::json::object const& config) const
{
if (!config.contains("dos_guard"))
return {};
return config.at("dos_guard").as_object();
}
std::uint32_t
get(boost::json::object const& config,
std::string const& key,
std::uint32_t const fallback) const
{
if (auto const c = getConfig(config); c)
return c->at(key).as_int64();
return fallback;
}
std::unordered_set<std::string> const
getWhitelist(boost::json::object const& config) const
{
using T = std::unordered_set<std::string> const;
auto const& c = getConfig(config);
if (!c)
return T();
if (!c->contains("whitelist"))
return T();
auto const& w = c->at("whitelist").as_array();
auto const transform = [](auto const& elem) {
return std::string(elem.as_string().c_str());
};
return T(
boost::transform_iterator(w.begin(), transform),
boost::transform_iterator(w.end(), transform));
}
public:
DOSGuard(boost::json::object const& config, boost::asio::io_context& ctx)
: ctx_(ctx)
, whitelist_(getWhitelist(config))
, maxFetches_(get(config, "max_fetches", 100))
, sweepInterval_(get(config, "sweep_interval", 1))
{
if (config.contains("dos_guard"))
{
auto dosGuardConfig = config.at("dos_guard").as_object();
if (dosGuardConfig.contains("max_fetches") &&
dosGuardConfig.contains("sweep_interval"))
{
maxFetches_ = dosGuardConfig.at("max_fetches").as_int64();
sweepInterval_ = dosGuardConfig.at("sweep_interval").as_int64();
}
if (dosGuardConfig.contains("whitelist"))
{
auto whitelist = dosGuardConfig.at("whitelist").as_array();
for (auto& ip : whitelist)
whitelist_.insert(ip.as_string().c_str());
}
}
createTimer();
}
@@ -55,20 +85,23 @@ public:
bool
isOk(std::string const& ip)
{
if (whitelist_.count(ip) > 0)
if (whitelist_.contains(ip))
return true;
std::unique_lock lck(mtx_);
auto it = ipFetchCount_.find(ip);
if (it == ipFetchCount_.end())
return true;
return it->second < maxFetches_;
}
bool
add(std::string const& ip, uint32_t numObjects)
{
if (whitelist_.count(ip) > 0)
if (whitelist_.contains(ip))
return true;
{
std::unique_lock lck(mtx_);
auto it = ipFetchCount_.find(ip);
@@ -77,6 +110,7 @@ public:
else
it->second += numObjects;
}
return isOk(ip);
}

View File

@@ -95,12 +95,12 @@ public:
boost::beast::flat_buffer&& b)
: ioc_(ioc)
, http_(std::move(socket))
, buffer_(std::move(b))
, backend_(backend)
, subscriptions_(subscriptions)
, balancer_(balancer)
, dosGuard_(dosGuard)
, counters_(counters)
, buffer_(std::move(b))
{
}
WsUpgrader(
@@ -115,12 +115,12 @@ public:
http::request<http::string_body> req)
: ioc_(ioc)
, http_(std::move(stream))
, buffer_(std::move(b))
, backend_(backend)
, subscriptions_(subscriptions)
, balancer_(balancer)
, dosGuard_(dosGuard)
, counters_(counters)
, buffer_(std::move(b))
, req_(std::move(req))
{
}

View File

@@ -92,12 +92,12 @@ public:
boost::beast::flat_buffer&& b)
: ioc_(ioc)
, https_(std::move(socket), ctx)
, buffer_(std::move(b))
, backend_(backend)
, subscriptions_(subscriptions)
, balancer_(balancer)
, dosGuard_(dosGuard)
, counters_(counters)
, buffer_(std::move(b))
{
}
SslWsUpgrader(
@@ -112,12 +112,12 @@ public:
http::request<http::string_body> req)
: ioc_(ioc)
, https_(std::move(stream))
, buffer_(std::move(b))
, backend_(backend)
, subscriptions_(subscriptions)
, balancer_(balancer)
, dosGuard_(dosGuard)
, counters_(counters)
, buffer_(std::move(b))
, req_(std::move(req))
{
}

View File

@@ -101,13 +101,13 @@ public:
DOSGuard& dosGuard,
RPC::Counters& counters,
boost::beast::flat_buffer&& buffer)
: ioc_(ioc)
: buffer_(std::move(buffer))
, ioc_(ioc)
, backend_(backend)
, subscriptions_(subscriptions)
, balancer_(balancer)
, dosGuard_(dosGuard)
, counters_(counters)
, buffer_(std::move(buffer))
{
}
virtual ~WsSession()

View File

@@ -2203,7 +2203,6 @@ TEST(Backend, CacheIntegration)
for (auto obj : objs)
{
bool found = false;
bool correct = false;
for (auto retObj : retObjs)
{
if (ripple::strHex(obj.first) ==