refactor to make WsSession thread safe

This commit is contained in:
CJ Cobb
2021-08-30 14:21:16 -04:00
parent 5427467ce2
commit ce9a2af33c
6 changed files with 176 additions and 79 deletions

View File

@@ -32,11 +32,11 @@
#include <functional> #include <functional>
#include <iostream> #include <iostream>
#include <memory> #include <memory>
#include <webserver/Listener.h>
#include <sstream> #include <sstream>
#include <string> #include <string>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include <webserver/Listener.h>
std::optional<boost::json::object> std::optional<boost::json::object>
parse_config(const char* filename) parse_config(const char* filename)
@@ -149,7 +149,7 @@ main(int argc, char* argv[])
// Manages clients subscribed to streams // Manages clients subscribed to streams
std::shared_ptr<SubscriptionManager> subscriptions{ std::shared_ptr<SubscriptionManager> subscriptions{
SubscriptionManager::make_SubscriptionManager()}; SubscriptionManager::make_SubscriptionManager(backend)};
// Tracks which ledgers have been validated by the // Tracks which ledgers have been validated by the
// network // network

View File

@@ -31,7 +31,7 @@ validateStreams(boost::json::object const& request)
return OK; return OK;
} }
void boost::json::object
subscribeToStreams( subscribeToStreams(
boost::json::object const& request, boost::json::object const& request,
std::shared_ptr<WsBase> session, std::shared_ptr<WsBase> session,
@@ -39,12 +39,13 @@ subscribeToStreams(
{ {
boost::json::array const& streams = request.at("streams").as_array(); boost::json::array const& streams = request.at("streams").as_array();
boost::json::object response;
for (auto const& stream : streams) for (auto const& stream : streams)
{ {
std::string s = stream.as_string().c_str(); std::string s = stream.as_string().c_str();
if (s == "ledger") if (s == "ledger")
manager.subLedger(session); response = manager.subLedger(session);
else if (s == "transactions") else if (s == "transactions")
manager.subTransactions(session); manager.subTransactions(session);
else if (s == "transactions_proposed") else if (s == "transactions_proposed")
@@ -52,6 +53,7 @@ subscribeToStreams(
else else
assert(false); assert(false);
} }
return response;
} }
void void
@@ -320,8 +322,10 @@ doSubscribe(Context const& context)
snapshot = std::move(snap); snapshot = std::move(snap);
} }
boost::json::object response;
if (request.contains("streams")) if (request.contains("streams"))
subscribeToStreams(request, context.session, *context.subscriptions); response = subscribeToStreams(
request, context.session, *context.subscriptions);
if (request.contains("accounts")) if (request.contains("accounts"))
subscribeToAccounts(request, context.session, *context.subscriptions); subscribeToAccounts(request, context.session, *context.subscriptions);
@@ -333,7 +337,6 @@ doSubscribe(Context const& context)
if (request.contains("books")) if (request.contains("books"))
subscribeToBooks(books, context.session, *context.subscriptions); subscribeToBooks(books, context.session, *context.subscriptions);
boost::json::object response = {{"status", "success"}};
if (snapshot.size()) if (snapshot.size())
response["offers"] = snapshot; response["offers"] = snapshot;
return response; return response;

View File

@@ -2,12 +2,55 @@
#include <webserver/SubscriptionManager.h> #include <webserver/SubscriptionManager.h>
#include <webserver/WsBase.h> #include <webserver/WsBase.h>
void boost::json::object
getLedgerPubMessage(
ripple::LedgerInfo const& lgrInfo,
ripple::Fees const& fees,
std::string const& ledgerRange,
std::uint32_t txnCount)
{
boost::json::object pubMsg;
pubMsg["type"] = "ledgerClosed";
pubMsg["ledger_index"] = lgrInfo.seq;
pubMsg["ledger_hash"] = to_string(lgrInfo.hash);
pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count();
pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped());
pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped());
pubMsg["reserve_base"] =
RPC::toBoostJson(fees.accountReserve(0).jsonClipped());
pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped());
pubMsg["validated_ledgers"] = ledgerRange;
pubMsg["txn_count"] = txnCount;
return pubMsg;
}
boost::json::object
SubscriptionManager::subLedger(std::shared_ptr<WsBase>& session) SubscriptionManager::subLedger(std::shared_ptr<WsBase>& session)
{
{ {
std::lock_guard lk(m_); std::lock_guard lk(m_);
streamSubscribers_[Ledgers].emplace(session); streamSubscribers_[Ledgers].emplace(session);
} }
auto ledgerRange = backend_->fetchLedgerRange();
assert(ledgerRange);
auto lgrInfo = backend_->fetchLedgerBySequence(ledgerRange->maxSequence);
assert(lgrInfo);
std::optional<ripple::Fees> fees;
fees = backend_->fetchFees(lgrInfo->seq);
assert(fees);
std::string range = std::to_string(ledgerRange->minSequence) + "-" +
std::to_string(ledgerRange->maxSequence);
auto pubMsg = getLedgerPubMessage(*lgrInfo, *fees, range, 0);
pubMsg.erase("txn_count");
pubMsg.erase("type");
return pubMsg;
}
void void
SubscriptionManager::unsubLedger(std::shared_ptr<WsBase>& session) SubscriptionManager::unsubLedger(std::shared_ptr<WsBase>& session)
@@ -81,7 +124,7 @@ SubscriptionManager::sendAll(
} }
else else
{ {
session->publishToStream(pubMsg); session->send(pubMsg);
++it; ++it;
} }
} }
@@ -94,23 +137,11 @@ SubscriptionManager::pubLedger(
std::string const& ledgerRange, std::string const& ledgerRange,
std::uint32_t txnCount) std::uint32_t txnCount)
{ {
boost::json::object pubMsg; std::cout << "publishing ledger" << std::endl;
sendAll(
pubMsg["type"] = "ledgerClosed"; boost::json::serialize(
pubMsg["ledger_index"] = lgrInfo.seq; getLedgerPubMessage(lgrInfo, fees, ledgerRange, txnCount)),
pubMsg["ledger_hash"] = to_string(lgrInfo.hash); streamSubscribers_[Ledgers]);
pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count();
pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped());
pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped());
pubMsg["reserve_base"] =
RPC::toBoostJson(fees.accountReserve(0).jsonClipped());
pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped());
pubMsg["validated_ledgers"] = ledgerRange;
pubMsg["txn_count"] = txnCount;
sendAll(boost::json::serialize(pubMsg), streamSubscribers_[Ledgers]);
} }
void void
@@ -123,6 +154,7 @@ SubscriptionManager::pubTransaction(
pubObj["transaction"] = RPC::toJson(*tx); pubObj["transaction"] = RPC::toJson(*tx);
pubObj["meta"] = RPC::toJson(*meta); pubObj["meta"] = RPC::toJson(*meta);
std::string pubMsg{boost::json::serialize(pubObj)}; std::string pubMsg{boost::json::serialize(pubObj)};
std::cout << "publishing txn" << std::endl;
sendAll(pubMsg, streamSubscribers_[Transactions]); sendAll(pubMsg, streamSubscribers_[Transactions]);
auto journal = ripple::debugLog(); auto journal = ripple::debugLog();
@@ -171,7 +203,7 @@ void
SubscriptionManager::forwardProposedTransaction( SubscriptionManager::forwardProposedTransaction(
boost::json::object const& response) boost::json::object const& response)
{ {
std::string pubMsg{boost::json::serialize(pubMsg)}; std::string pubMsg{boost::json::serialize(response)};
sendAll(pubMsg, streamSubscribers_[TransactionsProposed]); sendAll(pubMsg, streamSubscribers_[TransactionsProposed]);
auto transaction = response.at("transaction").as_object(); auto transaction = response.at("transaction").as_object();

View File

@@ -42,15 +42,22 @@ class SubscriptionManager
std::unordered_map<ripple::AccountID, subscriptions> std::unordered_map<ripple::AccountID, subscriptions>
accountProposedSubscribers_; accountProposedSubscribers_;
std::unordered_map<ripple::Book, subscriptions> bookSubscribers_; std::unordered_map<ripple::Book, subscriptions> bookSubscribers_;
std::shared_ptr<Backend::BackendInterface> backend_;
public: public:
static std::shared_ptr<SubscriptionManager> static std::shared_ptr<SubscriptionManager>
make_SubscriptionManager() make_SubscriptionManager(
std::shared_ptr<Backend::BackendInterface> const& b)
{ {
return std::make_shared<SubscriptionManager>(); return std::make_shared<SubscriptionManager>(b);
} }
void SubscriptionManager(std::shared_ptr<Backend::BackendInterface> const& b)
: backend_(b)
{
}
boost::json::object
subLedger(std::shared_ptr<WsBase>& session); subLedger(std::shared_ptr<WsBase>& session);
void void

View File

@@ -66,7 +66,7 @@ class WsBase
public: public:
// Send, that enables SubscriptionManager to publish to clients // Send, that enables SubscriptionManager to publish to clients
virtual void virtual void
publishToStream(std::string const& msg) = 0; send(std::string const& msg) = 0;
virtual ~WsBase() virtual ~WsBase()
{ {
@@ -97,7 +97,6 @@ class WsSession : public WsBase,
using std::enable_shared_from_this<WsSession<Derived>>::shared_from_this; using std::enable_shared_from_this<WsSession<Derived>>::shared_from_this;
boost::beast::flat_buffer buffer_; boost::beast::flat_buffer buffer_;
std::string responseBuffer_;
std::shared_ptr<BackendInterface> backend_; std::shared_ptr<BackendInterface> backend_;
// has to be a weak ptr because SubscriptionManager maintains collections // has to be a weak ptr because SubscriptionManager maintains collections
@@ -107,6 +106,7 @@ class WsSession : public WsBase,
std::shared_ptr<ETLLoadBalancer> balancer_; std::shared_ptr<ETLLoadBalancer> balancer_;
DOSGuard& dosGuard_; DOSGuard& dosGuard_;
std::mutex mtx_; std::mutex mtx_;
std::queue<std::string> messages_;
public: public:
explicit WsSession( explicit WsSession(
@@ -135,33 +135,37 @@ public:
} }
void void
publishToStream(std::string const& msg) sendNext()
{ {
// msg might not stick around for as long as it takes to write, so we
// make a copy and capture it in the lambda
auto msgSaved = std::make_unique<std::string>(msg);
std::lock_guard<std::mutex> lck(mtx_); std::lock_guard<std::mutex> lck(mtx_);
derived().ws().text(derived().ws().got_text());
// we send from SubscriptionManager as well as from on_read
derived().ws().async_write( derived().ws().async_write(
boost::asio::buffer(*msgSaved), boost::asio::buffer(messages_.front()),
[m = std::move(msgSaved), shared = shared_from_this()]( [shared = shared_from_this()](auto ec, size_t size) {
auto ec, size_t size) {
if (ec) if (ec)
return shared->wsFail(ec, "publishToStream"); return shared->wsFail(ec, "publishToStream");
size_t left = 0;
{
std::lock_guard<std::mutex> lck(shared->mtx_);
shared->messages_.pop();
left = shared->messages_.size();
}
if (left > 0)
shared->sendNext();
}); });
} }
void void
sendResponse(boost::json::object const& object) send(std::string const& msg)
{
size_t left = 0;
{ {
std::lock_guard<std::mutex> lck(mtx_); std::lock_guard<std::mutex> lck(mtx_);
responseBuffer_ = boost::json::serialize(object); messages_.push(msg);
derived().ws().text(derived().ws().got_text()); left = messages_.size();
derived().ws().async_write( }
boost::asio::buffer(responseBuffer_), // if the queue was previously empty, start the send chain
boost::beast::bind_front_handler( if (left == 1)
&WsSession::on_write, this->shared_from_this())); sendNext();
} }
void void
@@ -200,6 +204,8 @@ public:
do_read() do_read()
{ {
std::lock_guard<std::mutex> lck{mtx_}; std::lock_guard<std::mutex> lck{mtx_};
// Clear the buffer
buffer_.consume(buffer_.size());
// Read a message into our buffer // Read a message into our buffer
derived().ws().async_read( derived().ws().async_read(
buffer_, buffer_,
@@ -225,6 +231,9 @@ public:
response["error"] = "Too many requests. Slow down"; response["error"] = "Too many requests. Slow down";
else else
{ {
auto sendError = [this](auto error) {
send(boost::json::serialize(RPC::make_error(error)));
};
try try
{ {
boost::json::value raw = boost::json::parse(msg); boost::json::value raw = boost::json::parse(msg);
@@ -234,8 +243,7 @@ public:
{ {
auto range = backend_->fetchLedgerRange(); auto range = backend_->fetchLedgerRange();
if (!range) if (!range)
return sendResponse( return sendError(RPC::Error::rpcNOT_READY);
RPC::make_error(RPC::Error::rpcNOT_READY));
std::optional<RPC::Context> context = RPC::make_WsContext( std::optional<RPC::Context> context = RPC::make_WsContext(
request, request,
@@ -246,8 +254,7 @@ public:
*range); *range);
if (!context) if (!context)
return sendResponse( return sendError(RPC::Error::rpcBAD_SYNTAX);
RPC::make_error(RPC::Error::rpcBAD_SYNTAX));
auto id = auto id =
request.contains("id") ? request.at("id") : nullptr; request.contains("id") ? request.at("id") : nullptr;
@@ -270,14 +277,13 @@ public:
} }
else else
{ {
response = std::get<boost::json::object>(v); result = std::get<boost::json::object>(v);
} }
} }
catch (Backend::DatabaseTimeout const& t) catch (Backend::DatabaseTimeout const& t)
{ {
BOOST_LOG_TRIVIAL(error) << __func__ << " Database timeout"; BOOST_LOG_TRIVIAL(error) << __func__ << " Database timeout";
return sendResponse( return sendError(RPC::Error::rpcNOT_READY);
RPC::make_error(RPC::Error::rpcNOT_READY));
} }
} }
catch (std::exception const& e) catch (std::exception const& e)
@@ -285,27 +291,13 @@ public:
BOOST_LOG_TRIVIAL(error) BOOST_LOG_TRIVIAL(error)
<< __func__ << " caught exception : " << e.what(); << __func__ << " caught exception : " << e.what();
return sendResponse(RPC::make_error(RPC::Error::rpcINTERNAL)); return sendError(RPC::Error::rpcINTERNAL);
} }
} }
BOOST_LOG_TRIVIAL(trace) BOOST_LOG_TRIVIAL(trace)
<< __func__ << " : " << boost::json::serialize(response); << __func__ << " : " << boost::json::serialize(response);
sendResponse(response); send(boost::json::serialize(response));
}
void
on_write(boost::beast::error_code ec, std::size_t bytes_transferred)
{
boost::ignore_unused(bytes_transferred);
if (ec)
return wsFail(ec, "write");
// Clear the buffer
buffer_.consume(buffer_.size());
responseBuffer_.clear();
// Do another read
do_read(); do_read();
} }
}; };

73
test.py
View File

@@ -768,6 +768,7 @@ async def ledger_range(ip, port):
idx = rng.find("-") idx = rng.find("-")
return (int(rng[0:idx]),int(rng[idx+1:-1])) return (int(rng[0:idx]),int(rng[idx+1:-1]))
res = res["result"]
return (res["ledger_index_min"],res["ledger_index_max"]) return (res["ledger_index_min"],res["ledger_index_max"])
except websockets.exceptions.connectionclosederror as e: except websockets.exceptions.connectionclosederror as e:
print(e) print(e)
@@ -813,18 +814,78 @@ async def subscribe(ip, port):
address = 'ws://' + str(ip) + ':' + str(port) address = 'ws://' + str(ip) + ':' + str(port)
try: try:
async with websockets.connect(address) as ws: async with websockets.connect(address) as ws:
await ws.send(json.dumps({"command":"subscribe","streams":["ledger","transactions"],"books":[{"snapshot":True,"taker_pays":{"currency":"XRP"},"taker_gets":{"currency":"USD","issuer":"rhub8VRN55s94qWKDv6jmDy1pUykJzF3wq"}}]})) await ws.send(json.dumps({"command":"server_info"}));
#await ws.send(json.dumps({"command":"subscribe","streams":["ledger"]})) print(json.loads(await ws.recv()))
await ws.send(json.dumps({"command":"server_info"}));
print(json.loads(await ws.recv()))
await ws.send(json.dumps({"command":"server_info"}));
await ws.send(json.dumps({"command":"server_info"}));
await ws.send(json.dumps({"command":"server_info"}));
print(json.loads(await ws.recv()))
print(json.loads(await ws.recv()))
print(json.loads(await ws.recv()))
await ws.send(json.dumps({"command":"subscribe","streams":["ledger","transactions","transactions_proposed"],"books":[{"snapshot":True,"taker_pays":{"currency":"XRP"},"taker_gets":{"currency":"USD","issuer":"rhub8VRN55s94qWKDv6jmDy1pUykJzF3wq"}}]}))
#await ws.send(json.dumps({"command":"subscribe","streams":["manifests"]}))
while True: while True:
res = json.loads(await ws.recv()) res = json.loads(await ws.recv())
print(json.dumps(res,indent=4,sort_keys=True)) print(json.dumps(res,indent=4,sort_keys=True))
except websockets.exceptions.connectionclosederror as e: except websockets.exceptions.connectionclosederror as e:
print(e) print(e)
async def verifySubscribe(ip,clioPort,ripdPort):
clioAddress = 'ws://' + str(ip) + ':' + str(clioPort)
ripdAddress = 'ws://' + str(ip) + ':' + str(ripdPort)
ripdMsgs = []
clioMsgs = []
try:
async with websockets.connect(clioAddress) as ws1:
async with websockets.connect(ripdAddress) as ws2:
await ws1.send(json.dumps({"command":"server_info"}))
res = json.loads(await ws1.recv())
print(res)
start = int(res["result"]["info"]["complete_ledgers"].split("-")[1])
end = start + 10
streams = ["ledger","transactions"]
await ws1.send(json.dumps({"command":"subscribe","streams":streams})),
await ws2.send(json.dumps({"command":"subscribe","streams":streams}))
res1 = json.loads(await ws1.recv())["result"]
res2 = json.loads(await ws2.recv())["result"]
assert("validated_ledgers" in res1 and "validated_ledgers" in res2)
res1["validated_ledgers"] = ""
res2["validated_ledgers"] = ""
print(res1)
print(res2)
assert(json.dumps(res1,sort_keys=True) == json.dumps(res2,sort_keys=True))
while True:
print("getting second message")
res1 = json.loads(await ws1.recv())
res2 = json.loads(await ws2.recv())
print("got second message")
if "type" in res1 and res1["type"] == "ledgerClosed":
if int(res1["ledger_index"]) >= end:
print("breaking")
break
assert("validated_ledgers" in res1 and "validated_ledgers" in res2)
res1["validated_ledgers"] = ""
res2["validated_ledgers"] = ""
print(res1)
print(res2)
ripdMsgs.append(res1)
clioMsgs.append(res2)
#assert(json.dumps(res1,sort_keys=True) == json.dumps(res2,sort_keys=True))
assert(ripdMsgs == clioMsgs)
except websockets.exceptions.connectionclosederror as e:
print(e)
parser = argparse.ArgumentParser(description='test script for xrpl-reporting') parser = argparse.ArgumentParser(description='test script for xrpl-reporting')
parser.add_argument('action', choices=["account_info", "tx", "txs","account_tx", "account_tx_full","ledger_data", "ledger_data_full", "book_offers","ledger","ledger_range","ledger_entry", "ledgers", "ledger_entries","account_txs","account_infos","account_txs_full","book_offerses","ledger_diff","perf","fee","server_info", "gaps","subscribe"]) parser.add_argument('action', choices=["account_info", "tx", "txs","account_tx", "account_tx_full","ledger_data", "ledger_data_full", "book_offers","ledger","ledger_range","ledger_entry", "ledgers", "ledger_entries","account_txs","account_infos","account_txs_full","book_offerses","ledger_diff","perf","fee","server_info", "gaps","subscribe","verify_subscribe"])
parser.add_argument('--ip', default='127.0.0.1') parser.add_argument('--ip', default='127.0.0.1')
parser.add_argument('--port', default='8080') parser.add_argument('--port', default='8080')
@@ -836,8 +897,8 @@ parser.add_argument('--taker_pays_issuer',default='rvYAfWj5gh67oV6fW32ZzP3Aw4Eub
parser.add_argument('--taker_pays_currency',default='USD') parser.add_argument('--taker_pays_currency',default='USD')
parser.add_argument('--taker_gets_issuer') parser.add_argument('--taker_gets_issuer')
parser.add_argument('--taker_gets_currency',default='XRP') parser.add_argument('--taker_gets_currency',default='XRP')
parser.add_argument('--p2pIp', default='s2.ripple.com') parser.add_argument('--p2pIp', default='127.0.0.1')
parser.add_argument('--p2pPort', default='51233') parser.add_argument('--p2pPort', default='6006')
parser.add_argument('--verify',default=False) parser.add_argument('--verify',default=False)
parser.add_argument('--binary',default=True) parser.add_argument('--binary',default=True)
parser.add_argument('--forward',default=False) parser.add_argument('--forward',default=False)
@@ -891,6 +952,8 @@ def run(args):
print(missing) print(missing)
elif args.action == "subscribe": elif args.action == "subscribe":
asyncio.get_event_loop().run_until_complete(subscribe(args.ip,args.port)) asyncio.get_event_loop().run_until_complete(subscribe(args.ip,args.port))
elif args.action == "verify_subscribe":
asyncio.get_event_loop().run_until_complete(verifySubscribe(args.ip,args.port,args.p2pPort))
elif args.action == "account_info": elif args.action == "account_info":
res1 = asyncio.get_event_loop().run_until_complete( res1 = asyncio.get_event_loop().run_until_complete(
account_info(args.ip, args.port, args.account, args.ledger, args.binary)) account_info(args.ip, args.port, args.account, args.ledger, args.binary))