Advance DOSGuard (#96)

* Add C++20 features
* Make whitelist const
This commit is contained in:
undertome
2022-02-18 16:43:02 -06:00
committed by GitHub
parent 6567a5cc7d
commit c03b72ad51
17 changed files with 117 additions and 104 deletions

View File

@@ -1,5 +1,5 @@
set(CMAKE_CXX_EXTENSIONS OFF) set(CMAKE_CXX_EXTENSIONS OFF)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -Wall -Werror -Wno-dangling-else")

View File

@@ -159,7 +159,6 @@ BackendInterface::fetchBookOffers(
BookOffersPage page; BookOffersPage page;
const ripple::uint256 bookEnd = ripple::getQualityNext(book); const ripple::uint256 bookEnd = ripple::getQualityNext(book);
ripple::uint256 uTipIndex = book; ripple::uint256 uTipIndex = book;
bool done = false;
std::vector<ripple::uint256> keys; std::vector<ripple::uint256> keys;
auto getMillis = [](auto diff) { auto getMillis = [](auto diff) {
return std::chrono::duration_cast<std::chrono::milliseconds>(diff) return std::chrono::duration_cast<std::chrono::milliseconds>(diff)

View File

@@ -106,9 +106,9 @@ struct WriteCallbackData
template <class T, class B> template <class T, class B>
struct BulkWriteCallbackData : public WriteCallbackData<T, B> struct BulkWriteCallbackData : public WriteCallbackData<T, B>
{ {
std::atomic_int& numRemaining;
std::mutex& mtx; std::mutex& mtx;
std::condition_variable& cv; std::condition_variable& cv;
std::atomic_int& numRemaining;
BulkWriteCallbackData( BulkWriteCallbackData(
CassandraBackend const* b, CassandraBackend const* b,
T&& d, T&& d,
@@ -1006,7 +1006,6 @@ CassandraBackend::open(bool readOnly)
} }
return true; return true;
}; };
CassStatement* statement;
CassFuture* fut; CassFuture* fut;
bool setupSessionAndTable = false; bool setupSessionAndTable = false;
while (!setupSessionAndTable) while (!setupSessionAndTable)

View File

@@ -252,7 +252,7 @@ Pg::getSocket(boost::asio::yield_context& yield)
delete socket; delete socket;
}}; }};
return std::move(s); return s;
} }
PgResult PgResult
@@ -775,7 +775,6 @@ PgPool::PgPool(boost::asio::io_context& ioc, boost::json::object const& config)
if (config.contains("max_connections")) if (config.contains("max_connections"))
config_.max_connections = config.at("max_connections").as_int64(); config_.max_connections = config.at("max_connections").as_int64();
std::size_t timeout;
if (config.contains("timeout")) if (config.contains("timeout"))
config_.timeout = config_.timeout =
std::chrono::seconds(config.at("timeout").as_uint64()); std::chrono::seconds(config.at("timeout").as_uint64());
@@ -1111,7 +1110,7 @@ create table if not exists account_transactions10 partition of account_transacti
CREATE TABLE IF NOT EXISTS successor ( CREATE TABLE IF NOT EXISTS successor (
key bytea NOT NULL, key bytea NOT NULL,
ledger_seq bigint NOT NULL, ledger_seq bigint NOT NULL,
next bytea NOT NULL, next bytea NOT NULL,
PRIMARY KEY(key, ledger_seq) PRIMARY KEY(key, ledger_seq)
@@ -1711,8 +1710,6 @@ getLedger(
"total_coins, closing_time, prev_closing_time, close_time_res, " "total_coins, closing_time, prev_closing_time, close_time_res, "
"close_flags, ledger_seq FROM ledgers "; "close_flags, ledger_seq FROM ledgers ";
std::uint32_t expNumResults = 1;
if (auto ledgerSeq = std::get_if<std::uint32_t>(&whichLedger)) if (auto ledgerSeq = std::get_if<std::uint32_t>(&whichLedger))
{ {
sql << "WHERE ledger_seq = " + std::to_string(*ledgerSeq); sql << "WHERE ledger_seq = " + std::to_string(*ledgerSeq);

View File

@@ -781,7 +781,6 @@ PostgresBackend::doOnlineDelete(
std::uint32_t minLedger = rng->maxSequence - numLedgersToKeep; std::uint32_t minLedger = rng->maxSequence - numLedgersToKeep;
if (minLedger <= rng->minSequence) if (minLedger <= rng->minSequence)
return false; return false;
std::uint32_t limit = 2048;
PgQuery pgQuery(pgPool_); PgQuery pgQuery(pgPool_);
pgQuery("SET statement_timeout TO 0", yield); pgQuery("SET statement_timeout TO 0", yield);
std::optional<ripple::uint256> cursor; std::optional<ripple::uint256> cursor;

View File

@@ -22,13 +22,13 @@ ETLSourceImpl<Derived>::ETLSourceImpl(
std::shared_ptr<SubscriptionManager> subscriptions, std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<NetworkValidatedLedgers> networkValidatedLedgers, std::shared_ptr<NetworkValidatedLedgers> networkValidatedLedgers,
ETLLoadBalancer& balancer) ETLLoadBalancer& balancer)
: ioc_(ioContext) : resolver_(boost::asio::make_strand(ioContext))
, resolver_(boost::asio::make_strand(ioContext))
, timer_(ioContext)
, networkValidatedLedgers_(networkValidatedLedgers) , networkValidatedLedgers_(networkValidatedLedgers)
, backend_(backend) , backend_(backend)
, subscriptions_(subscriptions) , subscriptions_(subscriptions)
, balancer_(balancer) , balancer_(balancer)
, ioc_(ioContext)
, timer_(ioContext)
{ {
if (config.contains("ip")) if (config.contains("ip"))
{ {
@@ -852,6 +852,44 @@ ETLSourceImpl<Derived>::fetchLedger(
return {status, std::move(response)}; return {status, std::move(response)};
} }
static std::unique_ptr<ETLSource>
make_ETLSource(
boost::json::object const& config,
boost::asio::io_context& ioContext,
std::optional<std::reference_wrapper<boost::asio::ssl::context>> sslCtx,
std::shared_ptr<BackendInterface> backend,
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<NetworkValidatedLedgers> networkValidatedLedgers,
ETLLoadBalancer& balancer)
{
std::unique_ptr<ETLSource> src = nullptr;
if (sslCtx)
{
src = std::make_unique<SslETLSource>(
config,
ioContext,
sslCtx,
backend,
subscriptions,
networkValidatedLedgers,
balancer);
}
else
{
src = std::make_unique<PlainETLSource>(
config,
ioContext,
backend,
subscriptions,
networkValidatedLedgers,
balancer);
}
src->run();
return src;
}
ETLLoadBalancer::ETLLoadBalancer( ETLLoadBalancer::ETLLoadBalancer(
boost::json::object const& config, boost::json::object const& config,
boost::asio::io_context& ioContext, boost::asio::io_context& ioContext,
@@ -873,7 +911,7 @@ ETLLoadBalancer::ETLLoadBalancer(
for (auto& entry : config.at("etl_sources").as_array()) for (auto& entry : config.at("etl_sources").as_array())
{ {
std::unique_ptr<ETLSource> source = ETL::make_ETLSource( std::unique_ptr<ETLSource> source = make_ETLSource(
entry.as_object(), entry.as_object(),
ioContext, ioContext,
sslCtx, sslCtx,
@@ -914,7 +952,7 @@ ETLLoadBalancer::fetchLedger(
{ {
org::xrpl::rpc::v1::GetLedgerResponse response; org::xrpl::rpc::v1::GetLedgerResponse response;
bool success = execute( bool success = execute(
[&response, ledgerSequence, getObjects, getObjectNeighbors, this]( [&response, ledgerSequence, getObjects, getObjectNeighbors](
auto& source) { auto& source) {
auto [status, data] = source->fetchLedger( auto [status, data] = source->fetchLedger(
ledgerSequence, getObjects, getObjectNeighbors); ledgerSequence, getObjects, getObjectNeighbors);
@@ -1015,7 +1053,7 @@ ETLSourceImpl<Derived>::forwardToRippled(
// //
// https://github.com/ripple/rippled/blob/develop/cfg/rippled-example.cfg // https://github.com/ripple/rippled/blob/develop/cfg/rippled-example.cfg
ws->set_option(websocket::stream_base::decorator( ws->set_option(websocket::stream_base::decorator(
[&request, &clientIp](websocket::request_type& req) { [&clientIp](websocket::request_type& req) {
req.set( req.set(
http::field::user_agent, http::field::user_agent,
std::string(BOOST_BEAST_VERSION_STRING) + std::string(BOOST_BEAST_VERSION_STRING) +

View File

@@ -420,46 +420,6 @@ public:
} }
}; };
namespace ETL {
static std::unique_ptr<ETLSource>
make_ETLSource(
boost::json::object const& config,
boost::asio::io_context& ioContext,
std::optional<std::reference_wrapper<boost::asio::ssl::context>> sslCtx,
std::shared_ptr<BackendInterface> backend,
std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<NetworkValidatedLedgers> networkValidatedLedgers,
ETLLoadBalancer& balancer)
{
std::unique_ptr<ETLSource> src = nullptr;
if (sslCtx)
{
src = std::make_unique<SslETLSource>(
config,
ioContext,
sslCtx,
backend,
subscriptions,
networkValidatedLedgers,
balancer);
}
else
{
src = std::make_unique<PlainETLSource>(
config,
ioContext,
backend,
subscriptions,
networkValidatedLedgers,
balancer);
}
src->run();
return src;
}
} // namespace ETL
/// This class is used to manage connections to transaction processing processes /// This class is used to manage connections to transaction processing processes
/// This class spawns a listener for each etl source, which listens to messages /// This class spawns a listener for each etl source, which listens to messages
/// on the ledgers stream (to keep track of which ledgers have been validated by /// on the ledgers stream (to keep track of which ledgers have been validated by

View File

@@ -434,7 +434,6 @@ ReportingETL::buildNextLedger(org::xrpl::rpc::v1::GetLedgerResponse& rawData)
BOOST_LOG_TRIVIAL(debug) BOOST_LOG_TRIVIAL(debug)
<< __func__ << " object neighbors not included. using cache"; << __func__ << " object neighbors not included. using cache";
assert(backend_->cache().isFull()); assert(backend_->cache().isFull());
size_t idx = 0;
for (auto const& obj : cacheUpdates) for (auto const& obj : cacheUpdates)
{ {
if (modified.count(obj.key)) if (modified.count(obj.key))
@@ -952,11 +951,11 @@ ReportingETL::ReportingETL(
std::shared_ptr<SubscriptionManager> subscriptions, std::shared_ptr<SubscriptionManager> subscriptions,
std::shared_ptr<ETLLoadBalancer> balancer, std::shared_ptr<ETLLoadBalancer> balancer,
std::shared_ptr<NetworkValidatedLedgers> ledgers) std::shared_ptr<NetworkValidatedLedgers> ledgers)
: publishStrand_(ioc) : backend_(backend)
, ioContext_(ioc)
, backend_(backend)
, subscriptions_(subscriptions) , subscriptions_(subscriptions)
, loadBalancer_(balancer) , loadBalancer_(balancer)
, ioContext_(ioc)
, publishStrand_(ioc)
, networkValidatedLedgers_(ledgers) , networkValidatedLedgers_(ledgers)
{ {
if (config.contains("start_sequence")) if (config.contains("start_sequence"))

View File

@@ -33,11 +33,11 @@ struct Context
std::uint32_t version; std::uint32_t version;
boost::json::object const& params; boost::json::object const& params;
std::shared_ptr<BackendInterface const> const& backend; std::shared_ptr<BackendInterface const> const& backend;
std::shared_ptr<ETLLoadBalancer> const& balancer;
// this needs to be an actual shared_ptr, not a reference. The above // this needs to be an actual shared_ptr, not a reference. The above
// references refer to shared_ptr members of WsBase, but WsBase contains // references refer to shared_ptr members of WsBase, but WsBase contains
// SubscriptionManager as a weak_ptr, to prevent a shared_ptr cycle. // SubscriptionManager as a weak_ptr, to prevent a shared_ptr cycle.
std::shared_ptr<SubscriptionManager> subscriptions; std::shared_ptr<SubscriptionManager> subscriptions;
std::shared_ptr<ETLLoadBalancer> const& balancer;
std::shared_ptr<WsBase> session; std::shared_ptr<WsBase> session;
Backend::LedgerRange const& range; Backend::LedgerRange const& range;
Counters& counters; Counters& counters;

View File

@@ -67,10 +67,6 @@ doAccountInfo(Context const& context)
context.backend->fetchLedgerObject(key.key, lgrInfo.seq, context.yield); context.backend->fetchLedgerObject(key.key, lgrInfo.seq, context.yield);
auto end = std::chrono::system_clock::now(); auto end = std::chrono::system_clock::now();
auto time =
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
.count();
if (!dbResponse) if (!dbResponse)
{ {
return Status{Error::rpcACT_NOT_FOUND}; return Status{Error::rpcACT_NOT_FOUND};

View File

@@ -72,10 +72,6 @@ doLedgerData(Context const& context)
auto end = std::chrono::system_clock::now(); auto end = std::chrono::system_clock::now();
auto time =
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
.count();
boost::json::object header; boost::json::object header;
if (!cursor) if (!cursor)
{ {

View File

@@ -347,9 +347,6 @@ doLedgerEntry(Context const& context)
auto dbResponse = auto dbResponse =
context.backend->fetchLedgerObject(key, lgrInfo.seq, context.yield); context.backend->fetchLedgerObject(key, lgrInfo.seq, context.yield);
auto end = std::chrono::system_clock::now(); auto end = std::chrono::system_clock::now();
auto time =
std::chrono::duration_cast<std::chrono::microseconds>(end - start)
.count();
if (!dbResponse or dbResponse->size() == 0) if (!dbResponse or dbResponse->size() == 0)
return Status{Error::rpcLGR_NOT_FOUND}; return Status{Error::rpcLGR_NOT_FOUND};
@@ -372,4 +369,4 @@ doLedgerEntry(Context const& context)
return response; return response;
} }
} // namespace RPC } // namespace RPC

View File

@@ -8,33 +8,63 @@
class DOSGuard class DOSGuard
{ {
std::unordered_map<std::string, uint32_t> ipFetchCount_;
uint32_t maxFetches_ = 100;
uint32_t sweepInterval_ = 1;
std::unordered_set<std::string> whitelist_;
boost::asio::io_context& ctx_; boost::asio::io_context& ctx_;
std::mutex mtx_; std::mutex mtx_; // protects ipFetchCount_
std::unordered_map<std::string, std::uint32_t> ipFetchCount_;
std::unordered_set<std::string> const whitelist_;
std::uint32_t const maxFetches_;
std::uint32_t const sweepInterval_;
std::optional<boost::json::object>
getConfig(boost::json::object const& config) const
{
if (!config.contains("dos_guard"))
return {};
return config.at("dos_guard").as_object();
}
std::uint32_t
get(boost::json::object const& config,
std::string const& key,
std::uint32_t const fallback) const
{
if (auto const c = getConfig(config); c)
return c->at(key).as_int64();
return fallback;
}
std::unordered_set<std::string> const
getWhitelist(boost::json::object const& config) const
{
using T = std::unordered_set<std::string> const;
auto const& c = getConfig(config);
if (!c)
return T();
if (!c->contains("whitelist"))
return T();
auto const& w = c->at("whitelist").as_array();
auto const transform = [](auto const& elem) {
return std::string(elem.as_string().c_str());
};
return T(
boost::transform_iterator(w.begin(), transform),
boost::transform_iterator(w.end(), transform));
}
public: public:
DOSGuard(boost::json::object const& config, boost::asio::io_context& ctx) DOSGuard(boost::json::object const& config, boost::asio::io_context& ctx)
: ctx_(ctx) : ctx_(ctx)
, whitelist_(getWhitelist(config))
, maxFetches_(get(config, "max_fetches", 100))
, sweepInterval_(get(config, "sweep_interval", 1))
{ {
if (config.contains("dos_guard"))
{
auto dosGuardConfig = config.at("dos_guard").as_object();
if (dosGuardConfig.contains("max_fetches") &&
dosGuardConfig.contains("sweep_interval"))
{
maxFetches_ = dosGuardConfig.at("max_fetches").as_int64();
sweepInterval_ = dosGuardConfig.at("sweep_interval").as_int64();
}
if (dosGuardConfig.contains("whitelist"))
{
auto whitelist = dosGuardConfig.at("whitelist").as_array();
for (auto& ip : whitelist)
whitelist_.insert(ip.as_string().c_str());
}
}
createTimer(); createTimer();
} }
@@ -55,20 +85,23 @@ public:
bool bool
isOk(std::string const& ip) isOk(std::string const& ip)
{ {
if (whitelist_.count(ip) > 0) if (whitelist_.contains(ip))
return true; return true;
std::unique_lock lck(mtx_); std::unique_lock lck(mtx_);
auto it = ipFetchCount_.find(ip); auto it = ipFetchCount_.find(ip);
if (it == ipFetchCount_.end()) if (it == ipFetchCount_.end())
return true; return true;
return it->second < maxFetches_; return it->second < maxFetches_;
} }
bool bool
add(std::string const& ip, uint32_t numObjects) add(std::string const& ip, uint32_t numObjects)
{ {
if (whitelist_.count(ip) > 0) if (whitelist_.contains(ip))
return true; return true;
{ {
std::unique_lock lck(mtx_); std::unique_lock lck(mtx_);
auto it = ipFetchCount_.find(ip); auto it = ipFetchCount_.find(ip);
@@ -77,6 +110,7 @@ public:
else else
it->second += numObjects; it->second += numObjects;
} }
return isOk(ip); return isOk(ip);
} }

View File

@@ -95,12 +95,12 @@ public:
boost::beast::flat_buffer&& b) boost::beast::flat_buffer&& b)
: ioc_(ioc) : ioc_(ioc)
, http_(std::move(socket)) , http_(std::move(socket))
, buffer_(std::move(b))
, backend_(backend) , backend_(backend)
, subscriptions_(subscriptions) , subscriptions_(subscriptions)
, balancer_(balancer) , balancer_(balancer)
, dosGuard_(dosGuard) , dosGuard_(dosGuard)
, counters_(counters) , counters_(counters)
, buffer_(std::move(b))
{ {
} }
WsUpgrader( WsUpgrader(
@@ -115,12 +115,12 @@ public:
http::request<http::string_body> req) http::request<http::string_body> req)
: ioc_(ioc) : ioc_(ioc)
, http_(std::move(stream)) , http_(std::move(stream))
, buffer_(std::move(b))
, backend_(backend) , backend_(backend)
, subscriptions_(subscriptions) , subscriptions_(subscriptions)
, balancer_(balancer) , balancer_(balancer)
, dosGuard_(dosGuard) , dosGuard_(dosGuard)
, counters_(counters) , counters_(counters)
, buffer_(std::move(b))
, req_(std::move(req)) , req_(std::move(req))
{ {
} }

View File

@@ -92,12 +92,12 @@ public:
boost::beast::flat_buffer&& b) boost::beast::flat_buffer&& b)
: ioc_(ioc) : ioc_(ioc)
, https_(std::move(socket), ctx) , https_(std::move(socket), ctx)
, buffer_(std::move(b))
, backend_(backend) , backend_(backend)
, subscriptions_(subscriptions) , subscriptions_(subscriptions)
, balancer_(balancer) , balancer_(balancer)
, dosGuard_(dosGuard) , dosGuard_(dosGuard)
, counters_(counters) , counters_(counters)
, buffer_(std::move(b))
{ {
} }
SslWsUpgrader( SslWsUpgrader(
@@ -112,12 +112,12 @@ public:
http::request<http::string_body> req) http::request<http::string_body> req)
: ioc_(ioc) : ioc_(ioc)
, https_(std::move(stream)) , https_(std::move(stream))
, buffer_(std::move(b))
, backend_(backend) , backend_(backend)
, subscriptions_(subscriptions) , subscriptions_(subscriptions)
, balancer_(balancer) , balancer_(balancer)
, dosGuard_(dosGuard) , dosGuard_(dosGuard)
, counters_(counters) , counters_(counters)
, buffer_(std::move(b))
, req_(std::move(req)) , req_(std::move(req))
{ {
} }

View File

@@ -101,13 +101,13 @@ public:
DOSGuard& dosGuard, DOSGuard& dosGuard,
RPC::Counters& counters, RPC::Counters& counters,
boost::beast::flat_buffer&& buffer) boost::beast::flat_buffer&& buffer)
: ioc_(ioc) : buffer_(std::move(buffer))
, ioc_(ioc)
, backend_(backend) , backend_(backend)
, subscriptions_(subscriptions) , subscriptions_(subscriptions)
, balancer_(balancer) , balancer_(balancer)
, dosGuard_(dosGuard) , dosGuard_(dosGuard)
, counters_(counters) , counters_(counters)
, buffer_(std::move(buffer))
{ {
} }
virtual ~WsSession() virtual ~WsSession()

View File

@@ -2203,7 +2203,6 @@ TEST(Backend, CacheIntegration)
for (auto obj : objs) for (auto obj : objs)
{ {
bool found = false; bool found = false;
bool correct = false;
for (auto retObj : retObjs) for (auto retObj : retObjs)
{ {
if (ripple::strHex(obj.first) == if (ripple::strHex(obj.first) ==