refactor to make WsSession thread safe

This commit is contained in:
CJ Cobb
2021-08-30 14:21:16 -04:00
parent 5427467ce2
commit ce9a2af33c
6 changed files with 176 additions and 79 deletions

View File

@@ -32,11 +32,11 @@
#include <functional>
#include <iostream>
#include <memory>
#include <webserver/Listener.h>
#include <sstream>
#include <string>
#include <thread>
#include <vector>
#include <webserver/Listener.h>
std::optional<boost::json::object>
parse_config(const char* filename)
@@ -149,7 +149,7 @@ main(int argc, char* argv[])
// Manages clients subscribed to streams
std::shared_ptr<SubscriptionManager> subscriptions{
SubscriptionManager::make_SubscriptionManager()};
SubscriptionManager::make_SubscriptionManager(backend)};
// Tracks which ledgers have been validated by the
// network

View File

@@ -31,7 +31,7 @@ validateStreams(boost::json::object const& request)
return OK;
}
void
boost::json::object
subscribeToStreams(
boost::json::object const& request,
std::shared_ptr<WsBase> session,
@@ -39,12 +39,13 @@ subscribeToStreams(
{
boost::json::array const& streams = request.at("streams").as_array();
boost::json::object response;
for (auto const& stream : streams)
{
std::string s = stream.as_string().c_str();
if (s == "ledger")
manager.subLedger(session);
response = manager.subLedger(session);
else if (s == "transactions")
manager.subTransactions(session);
else if (s == "transactions_proposed")
@@ -52,6 +53,7 @@ subscribeToStreams(
else
assert(false);
}
return response;
}
void
@@ -320,8 +322,10 @@ doSubscribe(Context const& context)
snapshot = std::move(snap);
}
boost::json::object response;
if (request.contains("streams"))
subscribeToStreams(request, context.session, *context.subscriptions);
response = subscribeToStreams(
request, context.session, *context.subscriptions);
if (request.contains("accounts"))
subscribeToAccounts(request, context.session, *context.subscriptions);
@@ -333,7 +337,6 @@ doSubscribe(Context const& context)
if (request.contains("books"))
subscribeToBooks(books, context.session, *context.subscriptions);
boost::json::object response = {{"status", "success"}};
if (snapshot.size())
response["offers"] = snapshot;
return response;

View File

@@ -2,11 +2,54 @@
#include <webserver/SubscriptionManager.h>
#include <webserver/WsBase.h>
void
boost::json::object
getLedgerPubMessage(
ripple::LedgerInfo const& lgrInfo,
ripple::Fees const& fees,
std::string const& ledgerRange,
std::uint32_t txnCount)
{
boost::json::object pubMsg;
pubMsg["type"] = "ledgerClosed";
pubMsg["ledger_index"] = lgrInfo.seq;
pubMsg["ledger_hash"] = to_string(lgrInfo.hash);
pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count();
pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped());
pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped());
pubMsg["reserve_base"] =
RPC::toBoostJson(fees.accountReserve(0).jsonClipped());
pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped());
pubMsg["validated_ledgers"] = ledgerRange;
pubMsg["txn_count"] = txnCount;
return pubMsg;
}
boost::json::object
SubscriptionManager::subLedger(std::shared_ptr<WsBase>& session)
{
std::lock_guard lk(m_);
streamSubscribers_[Ledgers].emplace(session);
{
std::lock_guard lk(m_);
streamSubscribers_[Ledgers].emplace(session);
}
auto ledgerRange = backend_->fetchLedgerRange();
assert(ledgerRange);
auto lgrInfo = backend_->fetchLedgerBySequence(ledgerRange->maxSequence);
assert(lgrInfo);
std::optional<ripple::Fees> fees;
fees = backend_->fetchFees(lgrInfo->seq);
assert(fees);
std::string range = std::to_string(ledgerRange->minSequence) + "-" +
std::to_string(ledgerRange->maxSequence);
auto pubMsg = getLedgerPubMessage(*lgrInfo, *fees, range, 0);
pubMsg.erase("txn_count");
pubMsg.erase("type");
return pubMsg;
}
void
@@ -81,7 +124,7 @@ SubscriptionManager::sendAll(
}
else
{
session->publishToStream(pubMsg);
session->send(pubMsg);
++it;
}
}
@@ -94,23 +137,11 @@ SubscriptionManager::pubLedger(
std::string const& ledgerRange,
std::uint32_t txnCount)
{
boost::json::object pubMsg;
pubMsg["type"] = "ledgerClosed";
pubMsg["ledger_index"] = lgrInfo.seq;
pubMsg["ledger_hash"] = to_string(lgrInfo.hash);
pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count();
pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped());
pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped());
pubMsg["reserve_base"] =
RPC::toBoostJson(fees.accountReserve(0).jsonClipped());
pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped());
pubMsg["validated_ledgers"] = ledgerRange;
pubMsg["txn_count"] = txnCount;
sendAll(boost::json::serialize(pubMsg), streamSubscribers_[Ledgers]);
std::cout << "publishing ledger" << std::endl;
sendAll(
boost::json::serialize(
getLedgerPubMessage(lgrInfo, fees, ledgerRange, txnCount)),
streamSubscribers_[Ledgers]);
}
void
@@ -123,6 +154,7 @@ SubscriptionManager::pubTransaction(
pubObj["transaction"] = RPC::toJson(*tx);
pubObj["meta"] = RPC::toJson(*meta);
std::string pubMsg{boost::json::serialize(pubObj)};
std::cout << "publishing txn" << std::endl;
sendAll(pubMsg, streamSubscribers_[Transactions]);
auto journal = ripple::debugLog();
@@ -171,7 +203,7 @@ void
SubscriptionManager::forwardProposedTransaction(
boost::json::object const& response)
{
std::string pubMsg{boost::json::serialize(pubMsg)};
std::string pubMsg{boost::json::serialize(response)};
sendAll(pubMsg, streamSubscribers_[TransactionsProposed]);
auto transaction = response.at("transaction").as_object();

View File

@@ -42,15 +42,22 @@ class SubscriptionManager
std::unordered_map<ripple::AccountID, subscriptions>
accountProposedSubscribers_;
std::unordered_map<ripple::Book, subscriptions> bookSubscribers_;
std::shared_ptr<Backend::BackendInterface> backend_;
public:
static std::shared_ptr<SubscriptionManager>
make_SubscriptionManager()
make_SubscriptionManager(
std::shared_ptr<Backend::BackendInterface> const& b)
{
return std::make_shared<SubscriptionManager>();
return std::make_shared<SubscriptionManager>(b);
}
void
SubscriptionManager(std::shared_ptr<Backend::BackendInterface> const& b)
: backend_(b)
{
}
boost::json::object
subLedger(std::shared_ptr<WsBase>& session);
void

View File

@@ -66,7 +66,7 @@ class WsBase
public:
// Send, that enables SubscriptionManager to publish to clients
virtual void
publishToStream(std::string const& msg) = 0;
send(std::string const& msg) = 0;
virtual ~WsBase()
{
@@ -97,7 +97,6 @@ class WsSession : public WsBase,
using std::enable_shared_from_this<WsSession<Derived>>::shared_from_this;
boost::beast::flat_buffer buffer_;
std::string responseBuffer_;
std::shared_ptr<BackendInterface> backend_;
// has to be a weak ptr because SubscriptionManager maintains collections
@@ -107,6 +106,7 @@ class WsSession : public WsBase,
std::shared_ptr<ETLLoadBalancer> balancer_;
DOSGuard& dosGuard_;
std::mutex mtx_;
std::queue<std::string> messages_;
public:
explicit WsSession(
@@ -135,33 +135,37 @@ public:
}
void
publishToStream(std::string const& msg)
sendNext()
{
// msg might not stick around for as long as it takes to write, so we
// make a copy and capture it in the lambda
auto msgSaved = std::make_unique<std::string>(msg);
std::lock_guard<std::mutex> lck(mtx_);
derived().ws().text(derived().ws().got_text());
// we send from SubscriptionManager as well as from on_read
derived().ws().async_write(
boost::asio::buffer(*msgSaved),
[m = std::move(msgSaved), shared = shared_from_this()](
auto ec, size_t size) {
boost::asio::buffer(messages_.front()),
[shared = shared_from_this()](auto ec, size_t size) {
if (ec)
return shared->wsFail(ec, "publishToStream");
size_t left = 0;
{
std::lock_guard<std::mutex> lck(shared->mtx_);
shared->messages_.pop();
left = shared->messages_.size();
}
if (left > 0)
shared->sendNext();
});
}
void
sendResponse(boost::json::object const& object)
send(std::string const& msg)
{
std::lock_guard<std::mutex> lck(mtx_);
responseBuffer_ = boost::json::serialize(object);
derived().ws().text(derived().ws().got_text());
derived().ws().async_write(
boost::asio::buffer(responseBuffer_),
boost::beast::bind_front_handler(
&WsSession::on_write, this->shared_from_this()));
size_t left = 0;
{
std::lock_guard<std::mutex> lck(mtx_);
messages_.push(msg);
left = messages_.size();
}
// if the queue was previously empty, start the send chain
if (left == 1)
sendNext();
}
void
@@ -200,6 +204,8 @@ public:
do_read()
{
std::lock_guard<std::mutex> lck{mtx_};
// Clear the buffer
buffer_.consume(buffer_.size());
// Read a message into our buffer
derived().ws().async_read(
buffer_,
@@ -225,6 +231,9 @@ public:
response["error"] = "Too many requests. Slow down";
else
{
auto sendError = [this](auto error) {
send(boost::json::serialize(RPC::make_error(error)));
};
try
{
boost::json::value raw = boost::json::parse(msg);
@@ -234,8 +243,7 @@ public:
{
auto range = backend_->fetchLedgerRange();
if (!range)
return sendResponse(
RPC::make_error(RPC::Error::rpcNOT_READY));
return sendError(RPC::Error::rpcNOT_READY);
std::optional<RPC::Context> context = RPC::make_WsContext(
request,
@@ -246,8 +254,7 @@ public:
*range);
if (!context)
return sendResponse(
RPC::make_error(RPC::Error::rpcBAD_SYNTAX));
return sendError(RPC::Error::rpcBAD_SYNTAX);
auto id =
request.contains("id") ? request.at("id") : nullptr;
@@ -270,14 +277,13 @@ public:
}
else
{
response = std::get<boost::json::object>(v);
result = std::get<boost::json::object>(v);
}
}
catch (Backend::DatabaseTimeout const& t)
{
BOOST_LOG_TRIVIAL(error) << __func__ << " Database timeout";
return sendResponse(
RPC::make_error(RPC::Error::rpcNOT_READY));
return sendError(RPC::Error::rpcNOT_READY);
}
}
catch (std::exception const& e)
@@ -285,27 +291,13 @@ public:
BOOST_LOG_TRIVIAL(error)
<< __func__ << " caught exception : " << e.what();
return sendResponse(RPC::make_error(RPC::Error::rpcINTERNAL));
return sendError(RPC::Error::rpcINTERNAL);
}
}
BOOST_LOG_TRIVIAL(trace)
<< __func__ << " : " << boost::json::serialize(response);
sendResponse(response);
}
void
on_write(boost::beast::error_code ec, std::size_t bytes_transferred)
{
boost::ignore_unused(bytes_transferred);
if (ec)
return wsFail(ec, "write");
// Clear the buffer
buffer_.consume(buffer_.size());
responseBuffer_.clear();
// Do another read
send(boost::json::serialize(response));
do_read();
}
};