mirror of
https://github.com/XRPLF/clio.git
synced 2025-11-20 11:45:53 +00:00
refactor to make WsSession thread safe
This commit is contained in:
@@ -32,11 +32,11 @@
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <webserver/Listener.h>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
#include <webserver/Listener.h>
|
||||
|
||||
std::optional<boost::json::object>
|
||||
parse_config(const char* filename)
|
||||
@@ -149,7 +149,7 @@ main(int argc, char* argv[])
|
||||
|
||||
// Manages clients subscribed to streams
|
||||
std::shared_ptr<SubscriptionManager> subscriptions{
|
||||
SubscriptionManager::make_SubscriptionManager()};
|
||||
SubscriptionManager::make_SubscriptionManager(backend)};
|
||||
|
||||
// Tracks which ledgers have been validated by the
|
||||
// network
|
||||
|
||||
@@ -31,7 +31,7 @@ validateStreams(boost::json::object const& request)
|
||||
return OK;
|
||||
}
|
||||
|
||||
void
|
||||
boost::json::object
|
||||
subscribeToStreams(
|
||||
boost::json::object const& request,
|
||||
std::shared_ptr<WsBase> session,
|
||||
@@ -39,12 +39,13 @@ subscribeToStreams(
|
||||
{
|
||||
boost::json::array const& streams = request.at("streams").as_array();
|
||||
|
||||
boost::json::object response;
|
||||
for (auto const& stream : streams)
|
||||
{
|
||||
std::string s = stream.as_string().c_str();
|
||||
|
||||
if (s == "ledger")
|
||||
manager.subLedger(session);
|
||||
response = manager.subLedger(session);
|
||||
else if (s == "transactions")
|
||||
manager.subTransactions(session);
|
||||
else if (s == "transactions_proposed")
|
||||
@@ -52,6 +53,7 @@ subscribeToStreams(
|
||||
else
|
||||
assert(false);
|
||||
}
|
||||
return response;
|
||||
}
|
||||
|
||||
void
|
||||
@@ -320,8 +322,10 @@ doSubscribe(Context const& context)
|
||||
snapshot = std::move(snap);
|
||||
}
|
||||
|
||||
boost::json::object response;
|
||||
if (request.contains("streams"))
|
||||
subscribeToStreams(request, context.session, *context.subscriptions);
|
||||
response = subscribeToStreams(
|
||||
request, context.session, *context.subscriptions);
|
||||
|
||||
if (request.contains("accounts"))
|
||||
subscribeToAccounts(request, context.session, *context.subscriptions);
|
||||
@@ -333,7 +337,6 @@ doSubscribe(Context const& context)
|
||||
if (request.contains("books"))
|
||||
subscribeToBooks(books, context.session, *context.subscriptions);
|
||||
|
||||
boost::json::object response = {{"status", "success"}};
|
||||
if (snapshot.size())
|
||||
response["offers"] = snapshot;
|
||||
return response;
|
||||
|
||||
@@ -2,11 +2,54 @@
|
||||
#include <webserver/SubscriptionManager.h>
|
||||
#include <webserver/WsBase.h>
|
||||
|
||||
void
|
||||
boost::json::object
|
||||
getLedgerPubMessage(
|
||||
ripple::LedgerInfo const& lgrInfo,
|
||||
ripple::Fees const& fees,
|
||||
std::string const& ledgerRange,
|
||||
std::uint32_t txnCount)
|
||||
{
|
||||
boost::json::object pubMsg;
|
||||
|
||||
pubMsg["type"] = "ledgerClosed";
|
||||
pubMsg["ledger_index"] = lgrInfo.seq;
|
||||
pubMsg["ledger_hash"] = to_string(lgrInfo.hash);
|
||||
pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count();
|
||||
|
||||
pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped());
|
||||
pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped());
|
||||
pubMsg["reserve_base"] =
|
||||
RPC::toBoostJson(fees.accountReserve(0).jsonClipped());
|
||||
pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped());
|
||||
|
||||
pubMsg["validated_ledgers"] = ledgerRange;
|
||||
pubMsg["txn_count"] = txnCount;
|
||||
return pubMsg;
|
||||
}
|
||||
|
||||
boost::json::object
|
||||
SubscriptionManager::subLedger(std::shared_ptr<WsBase>& session)
|
||||
{
|
||||
std::lock_guard lk(m_);
|
||||
streamSubscribers_[Ledgers].emplace(session);
|
||||
{
|
||||
std::lock_guard lk(m_);
|
||||
streamSubscribers_[Ledgers].emplace(session);
|
||||
}
|
||||
auto ledgerRange = backend_->fetchLedgerRange();
|
||||
assert(ledgerRange);
|
||||
auto lgrInfo = backend_->fetchLedgerBySequence(ledgerRange->maxSequence);
|
||||
assert(lgrInfo);
|
||||
|
||||
std::optional<ripple::Fees> fees;
|
||||
fees = backend_->fetchFees(lgrInfo->seq);
|
||||
assert(fees);
|
||||
|
||||
std::string range = std::to_string(ledgerRange->minSequence) + "-" +
|
||||
std::to_string(ledgerRange->maxSequence);
|
||||
|
||||
auto pubMsg = getLedgerPubMessage(*lgrInfo, *fees, range, 0);
|
||||
pubMsg.erase("txn_count");
|
||||
pubMsg.erase("type");
|
||||
return pubMsg;
|
||||
}
|
||||
|
||||
void
|
||||
@@ -81,7 +124,7 @@ SubscriptionManager::sendAll(
|
||||
}
|
||||
else
|
||||
{
|
||||
session->publishToStream(pubMsg);
|
||||
session->send(pubMsg);
|
||||
++it;
|
||||
}
|
||||
}
|
||||
@@ -94,23 +137,11 @@ SubscriptionManager::pubLedger(
|
||||
std::string const& ledgerRange,
|
||||
std::uint32_t txnCount)
|
||||
{
|
||||
boost::json::object pubMsg;
|
||||
|
||||
pubMsg["type"] = "ledgerClosed";
|
||||
pubMsg["ledger_index"] = lgrInfo.seq;
|
||||
pubMsg["ledger_hash"] = to_string(lgrInfo.hash);
|
||||
pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count();
|
||||
|
||||
pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped());
|
||||
pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped());
|
||||
pubMsg["reserve_base"] =
|
||||
RPC::toBoostJson(fees.accountReserve(0).jsonClipped());
|
||||
pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped());
|
||||
|
||||
pubMsg["validated_ledgers"] = ledgerRange;
|
||||
pubMsg["txn_count"] = txnCount;
|
||||
|
||||
sendAll(boost::json::serialize(pubMsg), streamSubscribers_[Ledgers]);
|
||||
std::cout << "publishing ledger" << std::endl;
|
||||
sendAll(
|
||||
boost::json::serialize(
|
||||
getLedgerPubMessage(lgrInfo, fees, ledgerRange, txnCount)),
|
||||
streamSubscribers_[Ledgers]);
|
||||
}
|
||||
|
||||
void
|
||||
@@ -123,6 +154,7 @@ SubscriptionManager::pubTransaction(
|
||||
pubObj["transaction"] = RPC::toJson(*tx);
|
||||
pubObj["meta"] = RPC::toJson(*meta);
|
||||
std::string pubMsg{boost::json::serialize(pubObj)};
|
||||
std::cout << "publishing txn" << std::endl;
|
||||
sendAll(pubMsg, streamSubscribers_[Transactions]);
|
||||
|
||||
auto journal = ripple::debugLog();
|
||||
@@ -171,7 +203,7 @@ void
|
||||
SubscriptionManager::forwardProposedTransaction(
|
||||
boost::json::object const& response)
|
||||
{
|
||||
std::string pubMsg{boost::json::serialize(pubMsg)};
|
||||
std::string pubMsg{boost::json::serialize(response)};
|
||||
sendAll(pubMsg, streamSubscribers_[TransactionsProposed]);
|
||||
|
||||
auto transaction = response.at("transaction").as_object();
|
||||
|
||||
@@ -42,15 +42,22 @@ class SubscriptionManager
|
||||
std::unordered_map<ripple::AccountID, subscriptions>
|
||||
accountProposedSubscribers_;
|
||||
std::unordered_map<ripple::Book, subscriptions> bookSubscribers_;
|
||||
std::shared_ptr<Backend::BackendInterface> backend_;
|
||||
|
||||
public:
|
||||
static std::shared_ptr<SubscriptionManager>
|
||||
make_SubscriptionManager()
|
||||
make_SubscriptionManager(
|
||||
std::shared_ptr<Backend::BackendInterface> const& b)
|
||||
{
|
||||
return std::make_shared<SubscriptionManager>();
|
||||
return std::make_shared<SubscriptionManager>(b);
|
||||
}
|
||||
|
||||
void
|
||||
SubscriptionManager(std::shared_ptr<Backend::BackendInterface> const& b)
|
||||
: backend_(b)
|
||||
{
|
||||
}
|
||||
|
||||
boost::json::object
|
||||
subLedger(std::shared_ptr<WsBase>& session);
|
||||
|
||||
void
|
||||
|
||||
@@ -66,7 +66,7 @@ class WsBase
|
||||
public:
|
||||
// Send, that enables SubscriptionManager to publish to clients
|
||||
virtual void
|
||||
publishToStream(std::string const& msg) = 0;
|
||||
send(std::string const& msg) = 0;
|
||||
|
||||
virtual ~WsBase()
|
||||
{
|
||||
@@ -97,7 +97,6 @@ class WsSession : public WsBase,
|
||||
using std::enable_shared_from_this<WsSession<Derived>>::shared_from_this;
|
||||
|
||||
boost::beast::flat_buffer buffer_;
|
||||
std::string responseBuffer_;
|
||||
|
||||
std::shared_ptr<BackendInterface> backend_;
|
||||
// has to be a weak ptr because SubscriptionManager maintains collections
|
||||
@@ -107,6 +106,7 @@ class WsSession : public WsBase,
|
||||
std::shared_ptr<ETLLoadBalancer> balancer_;
|
||||
DOSGuard& dosGuard_;
|
||||
std::mutex mtx_;
|
||||
std::queue<std::string> messages_;
|
||||
|
||||
public:
|
||||
explicit WsSession(
|
||||
@@ -135,33 +135,37 @@ public:
|
||||
}
|
||||
|
||||
void
|
||||
publishToStream(std::string const& msg)
|
||||
sendNext()
|
||||
{
|
||||
// msg might not stick around for as long as it takes to write, so we
|
||||
// make a copy and capture it in the lambda
|
||||
auto msgSaved = std::make_unique<std::string>(msg);
|
||||
|
||||
std::lock_guard<std::mutex> lck(mtx_);
|
||||
derived().ws().text(derived().ws().got_text());
|
||||
// we send from SubscriptionManager as well as from on_read
|
||||
derived().ws().async_write(
|
||||
boost::asio::buffer(*msgSaved),
|
||||
[m = std::move(msgSaved), shared = shared_from_this()](
|
||||
auto ec, size_t size) {
|
||||
boost::asio::buffer(messages_.front()),
|
||||
[shared = shared_from_this()](auto ec, size_t size) {
|
||||
if (ec)
|
||||
return shared->wsFail(ec, "publishToStream");
|
||||
size_t left = 0;
|
||||
{
|
||||
std::lock_guard<std::mutex> lck(shared->mtx_);
|
||||
shared->messages_.pop();
|
||||
left = shared->messages_.size();
|
||||
}
|
||||
if (left > 0)
|
||||
shared->sendNext();
|
||||
});
|
||||
}
|
||||
|
||||
void
|
||||
sendResponse(boost::json::object const& object)
|
||||
send(std::string const& msg)
|
||||
{
|
||||
std::lock_guard<std::mutex> lck(mtx_);
|
||||
responseBuffer_ = boost::json::serialize(object);
|
||||
derived().ws().text(derived().ws().got_text());
|
||||
derived().ws().async_write(
|
||||
boost::asio::buffer(responseBuffer_),
|
||||
boost::beast::bind_front_handler(
|
||||
&WsSession::on_write, this->shared_from_this()));
|
||||
size_t left = 0;
|
||||
{
|
||||
std::lock_guard<std::mutex> lck(mtx_);
|
||||
messages_.push(msg);
|
||||
left = messages_.size();
|
||||
}
|
||||
// if the queue was previously empty, start the send chain
|
||||
if (left == 1)
|
||||
sendNext();
|
||||
}
|
||||
|
||||
void
|
||||
@@ -200,6 +204,8 @@ public:
|
||||
do_read()
|
||||
{
|
||||
std::lock_guard<std::mutex> lck{mtx_};
|
||||
// Clear the buffer
|
||||
buffer_.consume(buffer_.size());
|
||||
// Read a message into our buffer
|
||||
derived().ws().async_read(
|
||||
buffer_,
|
||||
@@ -225,6 +231,9 @@ public:
|
||||
response["error"] = "Too many requests. Slow down";
|
||||
else
|
||||
{
|
||||
auto sendError = [this](auto error) {
|
||||
send(boost::json::serialize(RPC::make_error(error)));
|
||||
};
|
||||
try
|
||||
{
|
||||
boost::json::value raw = boost::json::parse(msg);
|
||||
@@ -234,8 +243,7 @@ public:
|
||||
{
|
||||
auto range = backend_->fetchLedgerRange();
|
||||
if (!range)
|
||||
return sendResponse(
|
||||
RPC::make_error(RPC::Error::rpcNOT_READY));
|
||||
return sendError(RPC::Error::rpcNOT_READY);
|
||||
|
||||
std::optional<RPC::Context> context = RPC::make_WsContext(
|
||||
request,
|
||||
@@ -246,8 +254,7 @@ public:
|
||||
*range);
|
||||
|
||||
if (!context)
|
||||
return sendResponse(
|
||||
RPC::make_error(RPC::Error::rpcBAD_SYNTAX));
|
||||
return sendError(RPC::Error::rpcBAD_SYNTAX);
|
||||
|
||||
auto id =
|
||||
request.contains("id") ? request.at("id") : nullptr;
|
||||
@@ -270,14 +277,13 @@ public:
|
||||
}
|
||||
else
|
||||
{
|
||||
response = std::get<boost::json::object>(v);
|
||||
result = std::get<boost::json::object>(v);
|
||||
}
|
||||
}
|
||||
catch (Backend::DatabaseTimeout const& t)
|
||||
{
|
||||
BOOST_LOG_TRIVIAL(error) << __func__ << " Database timeout";
|
||||
return sendResponse(
|
||||
RPC::make_error(RPC::Error::rpcNOT_READY));
|
||||
return sendError(RPC::Error::rpcNOT_READY);
|
||||
}
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
@@ -285,27 +291,13 @@ public:
|
||||
BOOST_LOG_TRIVIAL(error)
|
||||
<< __func__ << " caught exception : " << e.what();
|
||||
|
||||
return sendResponse(RPC::make_error(RPC::Error::rpcINTERNAL));
|
||||
return sendError(RPC::Error::rpcINTERNAL);
|
||||
}
|
||||
}
|
||||
BOOST_LOG_TRIVIAL(trace)
|
||||
<< __func__ << " : " << boost::json::serialize(response);
|
||||
|
||||
sendResponse(response);
|
||||
}
|
||||
|
||||
void
|
||||
on_write(boost::beast::error_code ec, std::size_t bytes_transferred)
|
||||
{
|
||||
boost::ignore_unused(bytes_transferred);
|
||||
|
||||
if (ec)
|
||||
return wsFail(ec, "write");
|
||||
|
||||
// Clear the buffer
|
||||
buffer_.consume(buffer_.size());
|
||||
responseBuffer_.clear();
|
||||
// Do another read
|
||||
send(boost::json::serialize(response));
|
||||
do_read();
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user