refactor to make WsSession thread safe

This commit is contained in:
CJ Cobb
2021-08-30 14:21:16 -04:00
parent 5427467ce2
commit ce9a2af33c
6 changed files with 176 additions and 79 deletions

View File

@@ -32,11 +32,11 @@
#include <functional>
#include <iostream>
#include <memory>
#include <webserver/Listener.h>
#include <sstream>
#include <string>
#include <thread>
#include <vector>
#include <webserver/Listener.h>
std::optional<boost::json::object>
parse_config(const char* filename)
@@ -149,7 +149,7 @@ main(int argc, char* argv[])
// Manages clients subscribed to streams
std::shared_ptr<SubscriptionManager> subscriptions{
SubscriptionManager::make_SubscriptionManager()};
SubscriptionManager::make_SubscriptionManager(backend)};
// Tracks which ledgers have been validated by the
// network

View File

@@ -31,7 +31,7 @@ validateStreams(boost::json::object const& request)
return OK;
}
void
boost::json::object
subscribeToStreams(
boost::json::object const& request,
std::shared_ptr<WsBase> session,
@@ -39,12 +39,13 @@ subscribeToStreams(
{
boost::json::array const& streams = request.at("streams").as_array();
boost::json::object response;
for (auto const& stream : streams)
{
std::string s = stream.as_string().c_str();
if (s == "ledger")
manager.subLedger(session);
response = manager.subLedger(session);
else if (s == "transactions")
manager.subTransactions(session);
else if (s == "transactions_proposed")
@@ -52,6 +53,7 @@ subscribeToStreams(
else
assert(false);
}
return response;
}
void
@@ -320,8 +322,10 @@ doSubscribe(Context const& context)
snapshot = std::move(snap);
}
boost::json::object response;
if (request.contains("streams"))
subscribeToStreams(request, context.session, *context.subscriptions);
response = subscribeToStreams(
request, context.session, *context.subscriptions);
if (request.contains("accounts"))
subscribeToAccounts(request, context.session, *context.subscriptions);
@@ -333,7 +337,6 @@ doSubscribe(Context const& context)
if (request.contains("books"))
subscribeToBooks(books, context.session, *context.subscriptions);
boost::json::object response = {{"status", "success"}};
if (snapshot.size())
response["offers"] = snapshot;
return response;

View File

@@ -2,12 +2,55 @@
#include <webserver/SubscriptionManager.h>
#include <webserver/WsBase.h>
void
boost::json::object
getLedgerPubMessage(
ripple::LedgerInfo const& lgrInfo,
ripple::Fees const& fees,
std::string const& ledgerRange,
std::uint32_t txnCount)
{
boost::json::object pubMsg;
pubMsg["type"] = "ledgerClosed";
pubMsg["ledger_index"] = lgrInfo.seq;
pubMsg["ledger_hash"] = to_string(lgrInfo.hash);
pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count();
pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped());
pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped());
pubMsg["reserve_base"] =
RPC::toBoostJson(fees.accountReserve(0).jsonClipped());
pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped());
pubMsg["validated_ledgers"] = ledgerRange;
pubMsg["txn_count"] = txnCount;
return pubMsg;
}
boost::json::object
SubscriptionManager::subLedger(std::shared_ptr<WsBase>& session)
{
{
std::lock_guard lk(m_);
streamSubscribers_[Ledgers].emplace(session);
}
auto ledgerRange = backend_->fetchLedgerRange();
assert(ledgerRange);
auto lgrInfo = backend_->fetchLedgerBySequence(ledgerRange->maxSequence);
assert(lgrInfo);
std::optional<ripple::Fees> fees;
fees = backend_->fetchFees(lgrInfo->seq);
assert(fees);
std::string range = std::to_string(ledgerRange->minSequence) + "-" +
std::to_string(ledgerRange->maxSequence);
auto pubMsg = getLedgerPubMessage(*lgrInfo, *fees, range, 0);
pubMsg.erase("txn_count");
pubMsg.erase("type");
return pubMsg;
}
void
SubscriptionManager::unsubLedger(std::shared_ptr<WsBase>& session)
@@ -81,7 +124,7 @@ SubscriptionManager::sendAll(
}
else
{
session->publishToStream(pubMsg);
session->send(pubMsg);
++it;
}
}
@@ -94,23 +137,11 @@ SubscriptionManager::pubLedger(
std::string const& ledgerRange,
std::uint32_t txnCount)
{
boost::json::object pubMsg;
pubMsg["type"] = "ledgerClosed";
pubMsg["ledger_index"] = lgrInfo.seq;
pubMsg["ledger_hash"] = to_string(lgrInfo.hash);
pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count();
pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped());
pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped());
pubMsg["reserve_base"] =
RPC::toBoostJson(fees.accountReserve(0).jsonClipped());
pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped());
pubMsg["validated_ledgers"] = ledgerRange;
pubMsg["txn_count"] = txnCount;
sendAll(boost::json::serialize(pubMsg), streamSubscribers_[Ledgers]);
std::cout << "publishing ledger" << std::endl;
sendAll(
boost::json::serialize(
getLedgerPubMessage(lgrInfo, fees, ledgerRange, txnCount)),
streamSubscribers_[Ledgers]);
}
void
@@ -123,6 +154,7 @@ SubscriptionManager::pubTransaction(
pubObj["transaction"] = RPC::toJson(*tx);
pubObj["meta"] = RPC::toJson(*meta);
std::string pubMsg{boost::json::serialize(pubObj)};
std::cout << "publishing txn" << std::endl;
sendAll(pubMsg, streamSubscribers_[Transactions]);
auto journal = ripple::debugLog();
@@ -171,7 +203,7 @@ void
SubscriptionManager::forwardProposedTransaction(
boost::json::object const& response)
{
std::string pubMsg{boost::json::serialize(pubMsg)};
std::string pubMsg{boost::json::serialize(response)};
sendAll(pubMsg, streamSubscribers_[TransactionsProposed]);
auto transaction = response.at("transaction").as_object();

View File

@@ -42,15 +42,22 @@ class SubscriptionManager
std::unordered_map<ripple::AccountID, subscriptions>
accountProposedSubscribers_;
std::unordered_map<ripple::Book, subscriptions> bookSubscribers_;
std::shared_ptr<Backend::BackendInterface> backend_;
public:
static std::shared_ptr<SubscriptionManager>
make_SubscriptionManager()
make_SubscriptionManager(
std::shared_ptr<Backend::BackendInterface> const& b)
{
return std::make_shared<SubscriptionManager>();
return std::make_shared<SubscriptionManager>(b);
}
void
SubscriptionManager(std::shared_ptr<Backend::BackendInterface> const& b)
: backend_(b)
{
}
boost::json::object
subLedger(std::shared_ptr<WsBase>& session);
void

View File

@@ -66,7 +66,7 @@ class WsBase
public:
// Send, that enables SubscriptionManager to publish to clients
virtual void
publishToStream(std::string const& msg) = 0;
send(std::string const& msg) = 0;
virtual ~WsBase()
{
@@ -97,7 +97,6 @@ class WsSession : public WsBase,
using std::enable_shared_from_this<WsSession<Derived>>::shared_from_this;
boost::beast::flat_buffer buffer_;
std::string responseBuffer_;
std::shared_ptr<BackendInterface> backend_;
// has to be a weak ptr because SubscriptionManager maintains collections
@@ -107,6 +106,7 @@ class WsSession : public WsBase,
std::shared_ptr<ETLLoadBalancer> balancer_;
DOSGuard& dosGuard_;
std::mutex mtx_;
std::queue<std::string> messages_;
public:
explicit WsSession(
@@ -135,33 +135,37 @@ public:
}
void
publishToStream(std::string const& msg)
sendNext()
{
// msg might not stick around for as long as it takes to write, so we
// make a copy and capture it in the lambda
auto msgSaved = std::make_unique<std::string>(msg);
std::lock_guard<std::mutex> lck(mtx_);
derived().ws().text(derived().ws().got_text());
// we send from SubscriptionManager as well as from on_read
derived().ws().async_write(
boost::asio::buffer(*msgSaved),
[m = std::move(msgSaved), shared = shared_from_this()](
auto ec, size_t size) {
boost::asio::buffer(messages_.front()),
[shared = shared_from_this()](auto ec, size_t size) {
if (ec)
return shared->wsFail(ec, "publishToStream");
size_t left = 0;
{
std::lock_guard<std::mutex> lck(shared->mtx_);
shared->messages_.pop();
left = shared->messages_.size();
}
if (left > 0)
shared->sendNext();
});
}
void
sendResponse(boost::json::object const& object)
send(std::string const& msg)
{
size_t left = 0;
{
std::lock_guard<std::mutex> lck(mtx_);
responseBuffer_ = boost::json::serialize(object);
derived().ws().text(derived().ws().got_text());
derived().ws().async_write(
boost::asio::buffer(responseBuffer_),
boost::beast::bind_front_handler(
&WsSession::on_write, this->shared_from_this()));
messages_.push(msg);
left = messages_.size();
}
// if the queue was previously empty, start the send chain
if (left == 1)
sendNext();
}
void
@@ -200,6 +204,8 @@ public:
do_read()
{
std::lock_guard<std::mutex> lck{mtx_};
// Clear the buffer
buffer_.consume(buffer_.size());
// Read a message into our buffer
derived().ws().async_read(
buffer_,
@@ -225,6 +231,9 @@ public:
response["error"] = "Too many requests. Slow down";
else
{
auto sendError = [this](auto error) {
send(boost::json::serialize(RPC::make_error(error)));
};
try
{
boost::json::value raw = boost::json::parse(msg);
@@ -234,8 +243,7 @@ public:
{
auto range = backend_->fetchLedgerRange();
if (!range)
return sendResponse(
RPC::make_error(RPC::Error::rpcNOT_READY));
return sendError(RPC::Error::rpcNOT_READY);
std::optional<RPC::Context> context = RPC::make_WsContext(
request,
@@ -246,8 +254,7 @@ public:
*range);
if (!context)
return sendResponse(
RPC::make_error(RPC::Error::rpcBAD_SYNTAX));
return sendError(RPC::Error::rpcBAD_SYNTAX);
auto id =
request.contains("id") ? request.at("id") : nullptr;
@@ -270,14 +277,13 @@ public:
}
else
{
response = std::get<boost::json::object>(v);
result = std::get<boost::json::object>(v);
}
}
catch (Backend::DatabaseTimeout const& t)
{
BOOST_LOG_TRIVIAL(error) << __func__ << " Database timeout";
return sendResponse(
RPC::make_error(RPC::Error::rpcNOT_READY));
return sendError(RPC::Error::rpcNOT_READY);
}
}
catch (std::exception const& e)
@@ -285,27 +291,13 @@ public:
BOOST_LOG_TRIVIAL(error)
<< __func__ << " caught exception : " << e.what();
return sendResponse(RPC::make_error(RPC::Error::rpcINTERNAL));
return sendError(RPC::Error::rpcINTERNAL);
}
}
BOOST_LOG_TRIVIAL(trace)
<< __func__ << " : " << boost::json::serialize(response);
sendResponse(response);
}
void
on_write(boost::beast::error_code ec, std::size_t bytes_transferred)
{
boost::ignore_unused(bytes_transferred);
if (ec)
return wsFail(ec, "write");
// Clear the buffer
buffer_.consume(buffer_.size());
responseBuffer_.clear();
// Do another read
send(boost::json::serialize(response));
do_read();
}
};

73
test.py
View File

@@ -768,6 +768,7 @@ async def ledger_range(ip, port):
idx = rng.find("-")
return (int(rng[0:idx]),int(rng[idx+1:-1]))
res = res["result"]
return (res["ledger_index_min"],res["ledger_index_max"])
except websockets.exceptions.connectionclosederror as e:
print(e)
@@ -813,18 +814,78 @@ async def subscribe(ip, port):
address = 'ws://' + str(ip) + ':' + str(port)
try:
async with websockets.connect(address) as ws:
await ws.send(json.dumps({"command":"subscribe","streams":["ledger","transactions"],"books":[{"snapshot":True,"taker_pays":{"currency":"XRP"},"taker_gets":{"currency":"USD","issuer":"rhub8VRN55s94qWKDv6jmDy1pUykJzF3wq"}}]}))
#await ws.send(json.dumps({"command":"subscribe","streams":["ledger"]}))
await ws.send(json.dumps({"command":"server_info"}));
print(json.loads(await ws.recv()))
await ws.send(json.dumps({"command":"server_info"}));
print(json.loads(await ws.recv()))
await ws.send(json.dumps({"command":"server_info"}));
await ws.send(json.dumps({"command":"server_info"}));
await ws.send(json.dumps({"command":"server_info"}));
print(json.loads(await ws.recv()))
print(json.loads(await ws.recv()))
print(json.loads(await ws.recv()))
await ws.send(json.dumps({"command":"subscribe","streams":["ledger","transactions","transactions_proposed"],"books":[{"snapshot":True,"taker_pays":{"currency":"XRP"},"taker_gets":{"currency":"USD","issuer":"rhub8VRN55s94qWKDv6jmDy1pUykJzF3wq"}}]}))
#await ws.send(json.dumps({"command":"subscribe","streams":["manifests"]}))
while True:
res = json.loads(await ws.recv())
print(json.dumps(res,indent=4,sort_keys=True))
except websockets.exceptions.connectionclosederror as e:
print(e)
async def verifySubscribe(ip,clioPort,ripdPort):
clioAddress = 'ws://' + str(ip) + ':' + str(clioPort)
ripdAddress = 'ws://' + str(ip) + ':' + str(ripdPort)
ripdMsgs = []
clioMsgs = []
try:
async with websockets.connect(clioAddress) as ws1:
async with websockets.connect(ripdAddress) as ws2:
await ws1.send(json.dumps({"command":"server_info"}))
res = json.loads(await ws1.recv())
print(res)
start = int(res["result"]["info"]["complete_ledgers"].split("-")[1])
end = start + 10
streams = ["ledger","transactions"]
await ws1.send(json.dumps({"command":"subscribe","streams":streams})),
await ws2.send(json.dumps({"command":"subscribe","streams":streams}))
res1 = json.loads(await ws1.recv())["result"]
res2 = json.loads(await ws2.recv())["result"]
assert("validated_ledgers" in res1 and "validated_ledgers" in res2)
res1["validated_ledgers"] = ""
res2["validated_ledgers"] = ""
print(res1)
print(res2)
assert(json.dumps(res1,sort_keys=True) == json.dumps(res2,sort_keys=True))
while True:
print("getting second message")
res1 = json.loads(await ws1.recv())
res2 = json.loads(await ws2.recv())
print("got second message")
if "type" in res1 and res1["type"] == "ledgerClosed":
if int(res1["ledger_index"]) >= end:
print("breaking")
break
assert("validated_ledgers" in res1 and "validated_ledgers" in res2)
res1["validated_ledgers"] = ""
res2["validated_ledgers"] = ""
print(res1)
print(res2)
ripdMsgs.append(res1)
clioMsgs.append(res2)
#assert(json.dumps(res1,sort_keys=True) == json.dumps(res2,sort_keys=True))
assert(ripdMsgs == clioMsgs)
except websockets.exceptions.connectionclosederror as e:
print(e)
parser = argparse.ArgumentParser(description='test script for xrpl-reporting')
parser.add_argument('action', choices=["account_info", "tx", "txs","account_tx", "account_tx_full","ledger_data", "ledger_data_full", "book_offers","ledger","ledger_range","ledger_entry", "ledgers", "ledger_entries","account_txs","account_infos","account_txs_full","book_offerses","ledger_diff","perf","fee","server_info", "gaps","subscribe"])
parser.add_argument('action', choices=["account_info", "tx", "txs","account_tx", "account_tx_full","ledger_data", "ledger_data_full", "book_offers","ledger","ledger_range","ledger_entry", "ledgers", "ledger_entries","account_txs","account_infos","account_txs_full","book_offerses","ledger_diff","perf","fee","server_info", "gaps","subscribe","verify_subscribe"])
parser.add_argument('--ip', default='127.0.0.1')
parser.add_argument('--port', default='8080')
@@ -836,8 +897,8 @@ parser.add_argument('--taker_pays_issuer',default='rvYAfWj5gh67oV6fW32ZzP3Aw4Eub
parser.add_argument('--taker_pays_currency',default='USD')
parser.add_argument('--taker_gets_issuer')
parser.add_argument('--taker_gets_currency',default='XRP')
parser.add_argument('--p2pIp', default='s2.ripple.com')
parser.add_argument('--p2pPort', default='51233')
parser.add_argument('--p2pIp', default='127.0.0.1')
parser.add_argument('--p2pPort', default='6006')
parser.add_argument('--verify',default=False)
parser.add_argument('--binary',default=True)
parser.add_argument('--forward',default=False)
@@ -891,6 +952,8 @@ def run(args):
print(missing)
elif args.action == "subscribe":
asyncio.get_event_loop().run_until_complete(subscribe(args.ip,args.port))
elif args.action == "verify_subscribe":
asyncio.get_event_loop().run_until_complete(verifySubscribe(args.ip,args.port,args.p2pPort))
elif args.action == "account_info":
res1 = asyncio.get_event_loop().run_until_complete(
account_info(args.ip, args.port, args.account, args.ledger, args.binary))