clean up subscriptions code

* change the way sessions are removed from SubscriptionManager
* make the locking more fine grained
* add a mutex to WsSession to protect concurrent access to the websocket
This commit is contained in:
CJ Cobb
2021-08-24 17:23:49 -04:00
parent 6fcfd57d13
commit f4b7a88d95
6 changed files with 152 additions and 139 deletions

View File

@@ -1,12 +1,12 @@
#ifndef REPORTING_RPC_H_INCLUDED
#define REPORTING_RPC_H_INCLUDED
#include <ripple/protocol/ErrorCodes.h>
#include <boost/json.hpp>
#include <backend/BackendInterface.h>
#include <optional>
#include <string>
#include <variant>
#include <boost/json.hpp>
#include <ripple/protocol/ErrorCodes.h>
#include <backend/BackendInterface.h>
/*
* This file contains various classes necessary for executing RPC handlers.
* Context gives the handlers access to various other parts of the application
@@ -22,8 +22,7 @@ class WsBase;
class SubscriptionManager;
class ETLLoadBalancer;
namespace RPC
{
namespace RPC {
struct Context
{
@@ -31,8 +30,11 @@ struct Context
std::uint32_t version;
boost::json::object const& params;
std::shared_ptr<BackendInterface> const& backend;
std::shared_ptr<SubscriptionManager> const& subscriptions;
std::shared_ptr<ETLLoadBalancer> const& balancer;
// this needs to be an actual shared_ptr, not a reference. The above
// references refer to shared_ptr members of WsBase, but WsBase contains
// SubscriptionManager as a weak_ptr, to prevent a shared_ptr cycle.
std::shared_ptr<SubscriptionManager> subscriptions;
std::shared_ptr<WsBase> session;
Backend::LedgerRange const& range;
@@ -45,15 +47,16 @@ struct Context
std::shared_ptr<ETLLoadBalancer> const& balancer_,
std::shared_ptr<WsBase> const& session_,
Backend::LedgerRange const& range_)
: method(command_)
, version(version_)
, params(params_)
, backend(backend_)
, subscriptions(subscriptions_)
, balancer(balancer_)
, session(session_)
, range(range_)
{}
: method(command_)
, version(version_)
, params(params_)
, backend(backend_)
, subscriptions(subscriptions_)
, balancer(balancer_)
, session(session_)
, range(range_)
{
}
};
using Error = ripple::error_code_i;
@@ -62,14 +65,14 @@ struct Status
Error error = Error::rpcSUCCESS;
std::string message = "";
Status() {};
Status(){};
Status(Error error_) : error(error_) {};
Status(Error error_) : error(error_){};
Status(Error error_, std::string message_)
: error(error_)
, message(message_)
{}
: error(error_), message(message_)
{
}
/** Returns true if the Status is *not* OK. */
operator bool() const
@@ -94,7 +97,6 @@ make_error(Error err);
boost::json::object
make_error(Error err, std::string const& message);
std::optional<Context>
make_WsContext(
boost::json::object const& request,
@@ -114,7 +116,7 @@ make_HttpContext(
Result
buildResponse(Context const& ctx);
} // namespace RPC
#endif //REPORTING_RPC_H_INCLUDED
} // namespace RPC
#endif // REPORTING_RPC_H_INCLUDED

View File

@@ -142,7 +142,7 @@ private:
on_handshake(boost::beast::error_code ec, std::size_t bytes_used)
{
if (ec)
return wsFail(ec, "handshake");
return logError(ec, "handshake");
// Consume the portion of the buffer used by the handshake
buffer_.consume(bytes_used);

View File

@@ -16,34 +16,6 @@ SubscriptionManager::unsubLedger(std::shared_ptr<WsBase>& session)
streamSubscribers_[Ledgers].erase(session);
}
void
SubscriptionManager::pubLedger(
ripple::LedgerInfo const& lgrInfo,
ripple::Fees const& fees,
std::string const& ledgerRange,
std::uint32_t txnCount)
{
boost::json::object pubMsg;
pubMsg["type"] = "ledgerClosed";
pubMsg["ledger_index"] = lgrInfo.seq;
pubMsg["ledger_hash"] = to_string(lgrInfo.hash);
pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count();
pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped());
pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped());
pubMsg["reserve_base"] =
RPC::toBoostJson(fees.accountReserve(0).jsonClipped());
pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped());
pubMsg["validated_ledgers"] = ledgerRange;
pubMsg["txn_count"] = txnCount;
std::lock_guard lk(m_);
for (auto const& session : streamSubscribers_[Ledgers])
session->send(boost::json::serialize(pubMsg));
}
void
SubscriptionManager::subTransactions(std::shared_ptr<WsBase>& session)
{
@@ -94,27 +66,70 @@ SubscriptionManager::unsubBook(
bookSubscribers_[book].erase(session);
}
void
SubscriptionManager::sendAll(
std::string const& pubMsg,
std::unordered_set<std::shared_ptr<WsBase>>& subs)
{
std::lock_guard lk(m_);
for (auto it = subs.begin(); it != subs.end();)
{
auto& session = *it;
if (session->dead())
{
it = subs.erase(it);
}
else
{
session->publishToStream(pubMsg);
++it;
}
}
}
void
SubscriptionManager::pubLedger(
ripple::LedgerInfo const& lgrInfo,
ripple::Fees const& fees,
std::string const& ledgerRange,
std::uint32_t txnCount)
{
boost::json::object pubMsg;
pubMsg["type"] = "ledgerClosed";
pubMsg["ledger_index"] = lgrInfo.seq;
pubMsg["ledger_hash"] = to_string(lgrInfo.hash);
pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count();
pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped());
pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped());
pubMsg["reserve_base"] =
RPC::toBoostJson(fees.accountReserve(0).jsonClipped());
pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped());
pubMsg["validated_ledgers"] = ledgerRange;
pubMsg["txn_count"] = txnCount;
sendAll(boost::json::serialize(pubMsg), streamSubscribers_[Ledgers]);
}
void
SubscriptionManager::pubTransaction(
Backend::TransactionAndMetadata const& blobs,
std::uint32_t seq)
{
std::lock_guard lk(m_);
auto [tx, meta] = RPC::deserializeTxPlusMeta(blobs, seq);
boost::json::object pubMsg;
pubMsg["transaction"] = RPC::toJson(*tx);
pubMsg["meta"] = RPC::toJson(*meta);
for (auto const& session : streamSubscribers_[Transactions])
session->send(boost::json::serialize(pubMsg));
boost::json::object pubObj;
pubObj["transaction"] = RPC::toJson(*tx);
pubObj["meta"] = RPC::toJson(*meta);
std::string pubMsg{boost::json::serialize(pubObj)};
sendAll(pubMsg, streamSubscribers_[Transactions]);
auto journal = ripple::debugLog();
auto accounts = meta->getAffectedAccounts(journal);
for (ripple::AccountID const& account : accounts)
for (auto const& session : accountSubscribers_[account])
session->send(boost::json::serialize(pubMsg));
sendAll(pubMsg, accountSubscribers_[account]);
for (auto const& node : meta->peekNodes())
{
@@ -145,8 +160,7 @@ SubscriptionManager::pubTransaction(
ripple::Book book{
data->getFieldAmount(ripple::sfTakerGets).issue(),
data->getFieldAmount(ripple::sfTakerPays).issue()};
for (auto const& session : bookSubscribers_[book])
session->send(boost::json::serialize(pubMsg));
sendAll(pubMsg, bookSubscribers_[book]);
}
}
}
@@ -157,16 +171,14 @@ void
SubscriptionManager::forwardProposedTransaction(
boost::json::object const& response)
{
std::lock_guard lk(m_);
for (auto const& session : streamSubscribers_[TransactionsProposed])
session->send(boost::json::serialize(response));
std::string pubMsg{boost::json::serialize(pubMsg)};
sendAll(pubMsg, streamSubscribers_[TransactionsProposed]);
auto transaction = response.at("transaction").as_object();
auto accounts = RPC::getAccountsFromTransaction(transaction);
for (ripple::AccountID const& account : accounts)
for (auto const& session : accountProposedSubscribers_[account])
session->send(boost::json::serialize(response));
sendAll(pubMsg, accountProposedSubscribers_[account]);
}
void
@@ -201,25 +213,3 @@ SubscriptionManager::unsubProposedTransactions(std::shared_ptr<WsBase>& session)
streamSubscribers_[TransactionsProposed].erase(session);
}
void
SubscriptionManager::clearSession(WsBase* s)
{
std::lock_guard lk(m_);
// need the == operator. No-op delete
std::shared_ptr<WsBase> targetSession(s, [](WsBase*) {});
for (auto& stream : streamSubscribers_)
stream.erase(targetSession);
for (auto& [account, subscribers] : accountSubscribers_)
{
if (subscribers.find(targetSession) != subscribers.end())
accountSubscribers_[account].erase(targetSession);
}
for (auto& [account, subscribers] : accountProposedSubscribers_)
{
if (subscribers.find(targetSession) != subscribers.end())
accountProposedSubscribers_[account].erase(targetSession);
}
}

View File

@@ -27,7 +27,7 @@ class WsBase;
class SubscriptionManager
{
using subscriptions = std::set<std::shared_ptr<WsBase>>;
using subscriptions = std::unordered_set<std::shared_ptr<WsBase>>;
enum SubscriptionType {
Ledgers,
@@ -109,8 +109,11 @@ public:
void
unsubProposedTransactions(std::shared_ptr<WsBase>& session);
private:
void
clearSession(WsBase* session);
sendAll(
std::string const& pubMsg,
std::unordered_set<std::shared_ptr<WsBase>>& subs);
};
#endif // SUBSCRIPTION_MANAGER_H

View File

@@ -39,9 +39,10 @@ namespace websocket = boost::beast::websocket;
using tcp = boost::asio::ip::tcp;
inline void
wsFail(boost::beast::error_code ec, char const* what)
logError(boost::beast::error_code ec, char const* what)
{
std::cerr << what << ": " << ec.message() << "\n";
BOOST_LOG_TRIVIAL(debug)
<< __func__ << " : " << what << ": " << ec.message() << "\n";
}
inline boost::json::object
@@ -60,14 +61,29 @@ getDefaultWsResponse(boost::json::value const& id)
class WsBase
{
std::atomic_bool dead_ = false;
public:
// Send, that enables SubscriptionManager to publish to clients
virtual void
send(std::string&& msg) = 0;
publishToStream(std::string const& msg) = 0;
virtual ~WsBase()
{
}
void
wsFail(boost::beast::error_code ec, char const* what)
{
logError(ec, what);
dead_ = true;
}
bool
dead()
{
return dead_;
}
};
class SubscriptionManager;
@@ -81,12 +97,16 @@ class WsSession : public WsBase,
using std::enable_shared_from_this<WsSession<Derived>>::shared_from_this;
boost::beast::flat_buffer buffer_;
std::string response_;
std::string responseBuffer_;
std::shared_ptr<BackendInterface> backend_;
// has to be a weak ptr because SubscriptionManager maintains collections
// of std::shared_ptr<WsBase> objects. If this were shared, there would be
// a cyclical dependency that would block destruction
std::weak_ptr<SubscriptionManager> subscriptions_;
std::shared_ptr<ETLLoadBalancer> balancer_;
DOSGuard& dosGuard_;
std::mutex mtx_;
public:
explicit WsSession(
@@ -115,11 +135,31 @@ public:
}
void
send(std::string&& msg)
publishToStream(std::string const& msg)
{
// msg might not stick around for as long as it takes to write, so we
// make a copy and capture it in the lambda
auto msgSaved = std::make_unique<std::string>(msg);
std::lock_guard<std::mutex> lck(mtx_);
derived().ws().text(derived().ws().got_text());
// we send from SubscriptionManager as well as from on_read
derived().ws().async_write(
boost::asio::buffer(*msgSaved),
[m = std::move(msgSaved), shared = shared_from_this()](
auto ec, size_t size) {
if (ec)
return shared->wsFail(ec, "publishToStream");
});
}
void
sendResponse(boost::json::object const& object)
{
std::lock_guard<std::mutex> lck(mtx_);
responseBuffer_ = boost::json::serialize(object);
derived().ws().text(derived().ws().got_text());
derived().ws().async_write(
boost::asio::buffer(msg),
boost::asio::buffer(responseBuffer_),
boost::beast::bind_front_handler(
&WsSession::on_write, this->shared_from_this()));
}
@@ -156,20 +196,10 @@ public:
do_read();
}
void
do_close()
{
std::shared_ptr<SubscriptionManager> mgr = subscriptions_.lock();
if (!mgr)
return;
mgr->clearSession(this);
}
void
do_read()
{
std::lock_guard<std::mutex> lck{mtx_};
// Read a message into our buffer
derived().ws().async_read(
buffer_,
@@ -182,18 +212,11 @@ public:
{
boost::ignore_unused(bytes_transferred);
// This indicates that the session was closed
if (ec == boost::beast::websocket::error::closed)
{
return;
}
if (ec)
return wsFail(ec, "read");
std::string msg{
static_cast<char const*>(buffer_.data().data()), buffer_.size()};
// BOOST_LOG_TRIVIAL(debug) << __func__ << msg;
boost::json::object response;
auto ip = derived().ip();
BOOST_LOG_TRIVIAL(debug)
@@ -211,8 +234,8 @@ public:
{
auto range = backend_->fetchLedgerRange();
if (!range)
return send(boost::json::serialize(
RPC::make_error(RPC::Error::rpcNOT_READY)));
return sendResponse(
RPC::make_error(RPC::Error::rpcNOT_READY));
std::optional<RPC::Context> context = RPC::make_WsContext(
request,
@@ -223,13 +246,13 @@ public:
*range);
if (!context)
return send(boost::json::serialize(
RPC::make_error(RPC::Error::rpcBAD_SYNTAX)));
return sendResponse(
RPC::make_error(RPC::Error::rpcBAD_SYNTAX));
auto id =
request.contains("id") ? request.at("id") : nullptr;
auto response = getDefaultWsResponse(id);
response = getDefaultWsResponse(id);
boost::json::object& result =
response["result"].as_object();
@@ -243,23 +266,18 @@ public:
if (!id.is_null())
error["id"] = id;
error["request"] = request;
response_ = boost::json::serialize(error);
response = error;
}
else
{
auto json = std::get<boost::json::object>(v);
response_ = boost::json::serialize(json);
response = std::get<boost::json::object>(v);
}
if (!dosGuard_.add(ip, response_.size()))
result["warning"] = "Too many requests";
}
catch (Backend::DatabaseTimeout const& t)
{
BOOST_LOG_TRIVIAL(error) << __func__ << " Database timeout";
return send(boost::json::serialize(
RPC::make_error(RPC::Error::rpcNOT_READY)));
return sendResponse(
RPC::make_error(RPC::Error::rpcNOT_READY));
}
}
catch (std::exception const& e)
@@ -267,13 +285,13 @@ public:
BOOST_LOG_TRIVIAL(error)
<< __func__ << " caught exception : " << e.what();
return send(boost::json::serialize(
RPC::make_error(RPC::Error::rpcINTERNAL)));
return sendResponse(RPC::make_error(RPC::Error::rpcINTERNAL));
}
}
BOOST_LOG_TRIVIAL(trace) << __func__ << response_;
BOOST_LOG_TRIVIAL(trace)
<< __func__ << " : " << boost::json::serialize(response);
send(std::move(response_));
sendResponse(response);
}
void
@@ -286,7 +304,7 @@ public:
// Clear the buffer
buffer_.consume(buffer_.size());
responseBuffer_.clear();
// Do another read
do_read();
}

View File

@@ -813,8 +813,8 @@ async def subscribe(ip, port):
address = 'ws://' + str(ip) + ':' + str(port)
try:
async with websockets.connect(address) as ws:
await ws.send(json.dumps({"command":"subscribe","streams":["ledger"],"books":[{"taker_pays":{"currency":"XRP"},"taker_gets":{"currency":"USD","issuer":"rhub8VRN55s94qWKDv6jmDy1pUykJzF3wq"}}]}))
#await ws.send(json.dumps({"command":"subscribe","streams":["ledger","transactions"]}))
await ws.send(json.dumps({"command":"subscribe","streams":["ledger","transactions"],"books":[{"taker_pays":{"currency":"XRP"},"taker_gets":{"currency":"USD","issuer":"rhub8VRN55s94qWKDv6jmDy1pUykJzF3wq"}}]}))
#await ws.send(json.dumps({"command":"subscribe","streams":["ledger"]}))
while True:
res = json.loads(await ws.recv())
print(json.dumps(res,indent=4,sort_keys=True))