Hotfix for account_info, ledger_data and ip() (#134)

* remove option to create account from seed

* catch errors in derived().ip()

* checks keys.size() before accessing
This commit is contained in:
Nathan Nichols
2022-03-25 15:57:32 -05:00
committed by GitHub
parent ac6c4c25d6
commit e526083456
8 changed files with 76 additions and 43 deletions

View File

@@ -253,9 +253,8 @@ BackendInterface::fetchLedgerPage(
bool reachedEnd = false; bool reachedEnd = false;
while (keys.size() < limit && !reachedEnd) while (keys.size() < limit && !reachedEnd)
{ {
ripple::uint256 const& curCursor = keys.size() ? keys.back() ripple::uint256 const& curCursor =
: cursor ? *cursor keys.size() ? keys.back() : cursor ? *cursor : firstKey;
: firstKey;
uint32_t seq = outOfOrder ? range->maxSequence : ledgerSequence; uint32_t seq = outOfOrder ? range->maxSequence : ledgerSequence;
auto succ = fetchSuccessorKey(curCursor, seq, yield); auto succ = fetchSuccessorKey(curCursor, seq, yield);
if (!succ) if (!succ)
@@ -283,7 +282,7 @@ BackendInterface::fetchLedgerPage(
assert(false); assert(false);
} }
} }
if (!reachedEnd) if (keys.size() && !reachedEnd)
page.cursor = keys.back(); page.cursor = keys.back();
return page; return page;

View File

@@ -48,16 +48,8 @@ doAccountInfo(Context const& context)
// Get info on account. // Get info on account.
auto accountID = accountFromStringStrict(strIdent); auto accountID = accountFromStringStrict(strIdent);
if (!accountID) if (!accountID)
{
if (!request.contains("strict") || !request.at("strict").as_bool())
{
accountID = accountFromSeed(strIdent);
if (!accountID)
return Status{Error::rpcBAD_SEED};
}
else
return Status{Error::rpcACT_MALFORMED}; return Status{Error::rpcACT_MALFORMED};
}
assert(accountID.has_value()); assert(accountID.has_value());
auto key = ripple::keylet::account(accountID.value()); auto key = ripple::keylet::account(accountID.value());

View File

@@ -192,6 +192,10 @@ public:
} }
auto ip = derived().ip(); auto ip = derived().ip();
if (!ip)
return;
auto session = derived().shared_from_this(); auto session = derived().shared_from_this();
// Requests are handed using coroutines. Here we spawn a coroutine // Requests are handed using coroutines. Here we spawn a coroutine
@@ -208,7 +212,7 @@ public:
etl_, etl_,
dosGuard_, dosGuard_,
counters_, counters_,
ip, *ip,
session); session);
}); });
} }

View File

@@ -50,11 +50,18 @@ public:
return std::move(stream_); return std::move(stream_);
} }
std::string std::optional<std::string>
ip() ip()
{
try
{ {
return stream_.socket().remote_endpoint().address().to_string(); return stream_.socket().remote_endpoint().address().to_string();
} }
catch (std::exception const&)
{
return {};
}
}
// Start the asynchronous operation // Start the asynchronous operation
void void

View File

@@ -58,8 +58,10 @@ public:
return ws_; return ws_;
} }
std::string std::optional<std::string>
ip() ip()
{
try
{ {
return ws() return ws()
.next_layer() .next_layer()
@@ -68,6 +70,11 @@ public:
.address() .address()
.to_string(); .to_string();
} }
catch (std::exception const&)
{
return {};
}
}
~PlainWsSession() = default; ~PlainWsSession() = default;
}; };

View File

@@ -51,8 +51,10 @@ public:
return std::move(stream_); return std::move(stream_);
} }
std::string std::optional<std::string>
ip() ip()
{
try
{ {
return stream_.next_layer() return stream_.next_layer()
.socket() .socket()
@@ -60,6 +62,11 @@ public:
.address() .address()
.to_string(); .to_string();
} }
catch (std::exception const&)
{
return {};
}
}
// Start the asynchronous operation // Start the asynchronous operation
void void

View File

@@ -55,8 +55,11 @@ public:
{ {
return ws_; return ws_;
} }
std::string
std::optional<std::string>
ip() ip()
{
try
{ {
return ws() return ws()
.next_layer() .next_layer()
@@ -66,6 +69,11 @@ public:
.address() .address()
.to_string(); .to_string();
} }
catch (std::exception const&)
{
return {};
}
}
}; };
class SslWsUpgrader : public std::enable_shared_from_this<SslWsUpgrader> class SslWsUpgrader : public std::enable_shared_from_this<SslWsUpgrader>

View File

@@ -213,6 +213,10 @@ public:
void void
handle_request(std::string const&& msg, boost::asio::yield_context& yc) handle_request(std::string const&& msg, boost::asio::yield_context& yc)
{ {
auto ip = derived().ip();
if (!ip)
return;
boost::json::object response = {}; boost::json::object response = {};
auto sendError = [this](auto error) { auto sendError = [this](auto error) {
send(boost::json::serialize(RPC::make_error(error))); send(boost::json::serialize(RPC::make_error(error)));
@@ -221,6 +225,7 @@ public:
{ {
boost::json::value raw = boost::json::parse(msg); boost::json::value raw = boost::json::parse(msg);
boost::json::object request = raw.as_object(); boost::json::object request = raw.as_object();
BOOST_LOG_TRIVIAL(debug) << " received request : " << request; BOOST_LOG_TRIVIAL(debug) << " received request : " << request;
try try
{ {
@@ -238,7 +243,7 @@ public:
shared_from_this(), shared_from_this(),
*range, *range,
counters_, counters_,
derived().ip()); *ip);
if (!context) if (!context)
return sendError(RPC::Error::rpcBAD_SYNTAX); return sendError(RPC::Error::rpcBAD_SYNTAX);
@@ -287,7 +292,7 @@ public:
} }
std::string responseStr = boost::json::serialize(response); std::string responseStr = boost::json::serialize(response);
dosGuard_.add(derived().ip(), responseStr.size()); dosGuard_.add(*ip, responseStr.size());
send(std::move(responseStr)); send(std::move(responseStr));
} }
@@ -302,9 +307,13 @@ public:
std::string msg{ std::string msg{
static_cast<char const*>(buffer_.data().data()), buffer_.size()}; static_cast<char const*>(buffer_.data().data()), buffer_.size()};
auto ip = derived().ip(); auto ip = derived().ip();
if (!ip)
return;
BOOST_LOG_TRIVIAL(debug) BOOST_LOG_TRIVIAL(debug)
<< __func__ << " received request from ip = " << ip; << __func__ << " received request from ip = " << *ip;
if (!dosGuard_.isOk(ip)) if (!dosGuard_.isOk(*ip))
{ {
boost::json::object response; boost::json::object response;
response["error"] = "Too many requests. Slow down"; response["error"] = "Too many requests. Slow down";
@@ -312,7 +321,7 @@ public:
BOOST_LOG_TRIVIAL(trace) << __func__ << " : " << responseStr; BOOST_LOG_TRIVIAL(trace) << __func__ << " : " << responseStr;
dosGuard_.add(ip, responseStr.size()); dosGuard_.add(*ip, responseStr.size());
send(std::move(responseStr)); send(std::move(responseStr));
} }
else else