From ce9a2af33c21f7cc792e46386cc3b4569673b53e Mon Sep 17 00:00:00 2001 From: CJ Cobb Date: Mon, 30 Aug 2021 14:21:16 -0400 Subject: [PATCH] refactor to make WsSession thread safe --- src/main.cpp | 4 +- src/rpc/handlers/Subscribe.cpp | 11 ++-- src/webserver/SubscriptionManager.cpp | 76 ++++++++++++++++++-------- src/webserver/SubscriptionManager.h | 13 +++-- src/webserver/WsBase.h | 78 ++++++++++++--------------- test.py | 73 +++++++++++++++++++++++-- 6 files changed, 176 insertions(+), 79 deletions(-) diff --git a/src/main.cpp b/src/main.cpp index 1c313f1b..ef91d11d 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -32,11 +32,11 @@ #include #include #include -#include #include #include #include #include +#include std::optional parse_config(const char* filename) @@ -149,7 +149,7 @@ main(int argc, char* argv[]) // Manages clients subscribed to streams std::shared_ptr subscriptions{ - SubscriptionManager::make_SubscriptionManager()}; + SubscriptionManager::make_SubscriptionManager(backend)}; // Tracks which ledgers have been validated by the // network diff --git a/src/rpc/handlers/Subscribe.cpp b/src/rpc/handlers/Subscribe.cpp index de3157a0..4d79ad49 100644 --- a/src/rpc/handlers/Subscribe.cpp +++ b/src/rpc/handlers/Subscribe.cpp @@ -31,7 +31,7 @@ validateStreams(boost::json::object const& request) return OK; } -void +boost::json::object subscribeToStreams( boost::json::object const& request, std::shared_ptr session, @@ -39,12 +39,13 @@ subscribeToStreams( { boost::json::array const& streams = request.at("streams").as_array(); + boost::json::object response; for (auto const& stream : streams) { std::string s = stream.as_string().c_str(); if (s == "ledger") - manager.subLedger(session); + response = manager.subLedger(session); else if (s == "transactions") manager.subTransactions(session); else if (s == "transactions_proposed") @@ -52,6 +53,7 @@ subscribeToStreams( else assert(false); } + return response; } void @@ -320,8 +322,10 @@ doSubscribe(Context const& context) snapshot = std::move(snap); } + boost::json::object response; if (request.contains("streams")) - subscribeToStreams(request, context.session, *context.subscriptions); + response = subscribeToStreams( + request, context.session, *context.subscriptions); if (request.contains("accounts")) subscribeToAccounts(request, context.session, *context.subscriptions); @@ -333,7 +337,6 @@ doSubscribe(Context const& context) if (request.contains("books")) subscribeToBooks(books, context.session, *context.subscriptions); - boost::json::object response = {{"status", "success"}}; if (snapshot.size()) response["offers"] = snapshot; return response; diff --git a/src/webserver/SubscriptionManager.cpp b/src/webserver/SubscriptionManager.cpp index e36bbb14..0802c53c 100644 --- a/src/webserver/SubscriptionManager.cpp +++ b/src/webserver/SubscriptionManager.cpp @@ -2,11 +2,54 @@ #include #include -void +boost::json::object +getLedgerPubMessage( + ripple::LedgerInfo const& lgrInfo, + ripple::Fees const& fees, + std::string const& ledgerRange, + std::uint32_t txnCount) +{ + boost::json::object pubMsg; + + pubMsg["type"] = "ledgerClosed"; + pubMsg["ledger_index"] = lgrInfo.seq; + pubMsg["ledger_hash"] = to_string(lgrInfo.hash); + pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count(); + + pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped()); + pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped()); + pubMsg["reserve_base"] = + RPC::toBoostJson(fees.accountReserve(0).jsonClipped()); + pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped()); + + pubMsg["validated_ledgers"] = ledgerRange; + pubMsg["txn_count"] = txnCount; + return pubMsg; +} + +boost::json::object SubscriptionManager::subLedger(std::shared_ptr& session) { - std::lock_guard lk(m_); - streamSubscribers_[Ledgers].emplace(session); + { + std::lock_guard lk(m_); + streamSubscribers_[Ledgers].emplace(session); + } + auto ledgerRange = backend_->fetchLedgerRange(); + assert(ledgerRange); + auto lgrInfo = backend_->fetchLedgerBySequence(ledgerRange->maxSequence); + assert(lgrInfo); + + std::optional fees; + fees = backend_->fetchFees(lgrInfo->seq); + assert(fees); + + std::string range = std::to_string(ledgerRange->minSequence) + "-" + + std::to_string(ledgerRange->maxSequence); + + auto pubMsg = getLedgerPubMessage(*lgrInfo, *fees, range, 0); + pubMsg.erase("txn_count"); + pubMsg.erase("type"); + return pubMsg; } void @@ -81,7 +124,7 @@ SubscriptionManager::sendAll( } else { - session->publishToStream(pubMsg); + session->send(pubMsg); ++it; } } @@ -94,23 +137,11 @@ SubscriptionManager::pubLedger( std::string const& ledgerRange, std::uint32_t txnCount) { - boost::json::object pubMsg; - - pubMsg["type"] = "ledgerClosed"; - pubMsg["ledger_index"] = lgrInfo.seq; - pubMsg["ledger_hash"] = to_string(lgrInfo.hash); - pubMsg["ledger_time"] = lgrInfo.closeTime.time_since_epoch().count(); - - pubMsg["fee_ref"] = RPC::toBoostJson(fees.units.jsonClipped()); - pubMsg["fee_base"] = RPC::toBoostJson(fees.base.jsonClipped()); - pubMsg["reserve_base"] = - RPC::toBoostJson(fees.accountReserve(0).jsonClipped()); - pubMsg["reserve_inc"] = RPC::toBoostJson(fees.increment.jsonClipped()); - - pubMsg["validated_ledgers"] = ledgerRange; - pubMsg["txn_count"] = txnCount; - - sendAll(boost::json::serialize(pubMsg), streamSubscribers_[Ledgers]); + std::cout << "publishing ledger" << std::endl; + sendAll( + boost::json::serialize( + getLedgerPubMessage(lgrInfo, fees, ledgerRange, txnCount)), + streamSubscribers_[Ledgers]); } void @@ -123,6 +154,7 @@ SubscriptionManager::pubTransaction( pubObj["transaction"] = RPC::toJson(*tx); pubObj["meta"] = RPC::toJson(*meta); std::string pubMsg{boost::json::serialize(pubObj)}; + std::cout << "publishing txn" << std::endl; sendAll(pubMsg, streamSubscribers_[Transactions]); auto journal = ripple::debugLog(); @@ -171,7 +203,7 @@ void SubscriptionManager::forwardProposedTransaction( boost::json::object const& response) { - std::string pubMsg{boost::json::serialize(pubMsg)}; + std::string pubMsg{boost::json::serialize(response)}; sendAll(pubMsg, streamSubscribers_[TransactionsProposed]); auto transaction = response.at("transaction").as_object(); diff --git a/src/webserver/SubscriptionManager.h b/src/webserver/SubscriptionManager.h index 011f6c6f..be93a726 100644 --- a/src/webserver/SubscriptionManager.h +++ b/src/webserver/SubscriptionManager.h @@ -42,15 +42,22 @@ class SubscriptionManager std::unordered_map accountProposedSubscribers_; std::unordered_map bookSubscribers_; + std::shared_ptr backend_; public: static std::shared_ptr - make_SubscriptionManager() + make_SubscriptionManager( + std::shared_ptr const& b) { - return std::make_shared(); + return std::make_shared(b); } - void + SubscriptionManager(std::shared_ptr const& b) + : backend_(b) + { + } + + boost::json::object subLedger(std::shared_ptr& session); void diff --git a/src/webserver/WsBase.h b/src/webserver/WsBase.h index 75a6025f..adf0ef85 100644 --- a/src/webserver/WsBase.h +++ b/src/webserver/WsBase.h @@ -66,7 +66,7 @@ class WsBase public: // Send, that enables SubscriptionManager to publish to clients virtual void - publishToStream(std::string const& msg) = 0; + send(std::string const& msg) = 0; virtual ~WsBase() { @@ -97,7 +97,6 @@ class WsSession : public WsBase, using std::enable_shared_from_this>::shared_from_this; boost::beast::flat_buffer buffer_; - std::string responseBuffer_; std::shared_ptr backend_; // has to be a weak ptr because SubscriptionManager maintains collections @@ -107,6 +106,7 @@ class WsSession : public WsBase, std::shared_ptr balancer_; DOSGuard& dosGuard_; std::mutex mtx_; + std::queue messages_; public: explicit WsSession( @@ -135,33 +135,37 @@ public: } void - publishToStream(std::string const& msg) + sendNext() { - // msg might not stick around for as long as it takes to write, so we - // make a copy and capture it in the lambda - auto msgSaved = std::make_unique(msg); - std::lock_guard lck(mtx_); - derived().ws().text(derived().ws().got_text()); - // we send from SubscriptionManager as well as from on_read derived().ws().async_write( - boost::asio::buffer(*msgSaved), - [m = std::move(msgSaved), shared = shared_from_this()]( - auto ec, size_t size) { + boost::asio::buffer(messages_.front()), + [shared = shared_from_this()](auto ec, size_t size) { if (ec) return shared->wsFail(ec, "publishToStream"); + size_t left = 0; + { + std::lock_guard lck(shared->mtx_); + shared->messages_.pop(); + left = shared->messages_.size(); + } + if (left > 0) + shared->sendNext(); }); } + void - sendResponse(boost::json::object const& object) + send(std::string const& msg) { - std::lock_guard lck(mtx_); - responseBuffer_ = boost::json::serialize(object); - derived().ws().text(derived().ws().got_text()); - derived().ws().async_write( - boost::asio::buffer(responseBuffer_), - boost::beast::bind_front_handler( - &WsSession::on_write, this->shared_from_this())); + size_t left = 0; + { + std::lock_guard lck(mtx_); + messages_.push(msg); + left = messages_.size(); + } + // if the queue was previously empty, start the send chain + if (left == 1) + sendNext(); } void @@ -200,6 +204,8 @@ public: do_read() { std::lock_guard lck{mtx_}; + // Clear the buffer + buffer_.consume(buffer_.size()); // Read a message into our buffer derived().ws().async_read( buffer_, @@ -225,6 +231,9 @@ public: response["error"] = "Too many requests. Slow down"; else { + auto sendError = [this](auto error) { + send(boost::json::serialize(RPC::make_error(error))); + }; try { boost::json::value raw = boost::json::parse(msg); @@ -234,8 +243,7 @@ public: { auto range = backend_->fetchLedgerRange(); if (!range) - return sendResponse( - RPC::make_error(RPC::Error::rpcNOT_READY)); + return sendError(RPC::Error::rpcNOT_READY); std::optional context = RPC::make_WsContext( request, @@ -246,8 +254,7 @@ public: *range); if (!context) - return sendResponse( - RPC::make_error(RPC::Error::rpcBAD_SYNTAX)); + return sendError(RPC::Error::rpcBAD_SYNTAX); auto id = request.contains("id") ? request.at("id") : nullptr; @@ -270,14 +277,13 @@ public: } else { - response = std::get(v); + result = std::get(v); } } catch (Backend::DatabaseTimeout const& t) { BOOST_LOG_TRIVIAL(error) << __func__ << " Database timeout"; - return sendResponse( - RPC::make_error(RPC::Error::rpcNOT_READY)); + return sendError(RPC::Error::rpcNOT_READY); } } catch (std::exception const& e) @@ -285,27 +291,13 @@ public: BOOST_LOG_TRIVIAL(error) << __func__ << " caught exception : " << e.what(); - return sendResponse(RPC::make_error(RPC::Error::rpcINTERNAL)); + return sendError(RPC::Error::rpcINTERNAL); } } BOOST_LOG_TRIVIAL(trace) << __func__ << " : " << boost::json::serialize(response); - sendResponse(response); - } - - void - on_write(boost::beast::error_code ec, std::size_t bytes_transferred) - { - boost::ignore_unused(bytes_transferred); - - if (ec) - return wsFail(ec, "write"); - - // Clear the buffer - buffer_.consume(buffer_.size()); - responseBuffer_.clear(); - // Do another read + send(boost::json::serialize(response)); do_read(); } }; diff --git a/test.py b/test.py index 2f335974..b04eb9db 100755 --- a/test.py +++ b/test.py @@ -768,6 +768,7 @@ async def ledger_range(ip, port): idx = rng.find("-") return (int(rng[0:idx]),int(rng[idx+1:-1])) + res = res["result"] return (res["ledger_index_min"],res["ledger_index_max"]) except websockets.exceptions.connectionclosederror as e: print(e) @@ -813,18 +814,78 @@ async def subscribe(ip, port): address = 'ws://' + str(ip) + ':' + str(port) try: async with websockets.connect(address) as ws: - await ws.send(json.dumps({"command":"subscribe","streams":["ledger","transactions"],"books":[{"snapshot":True,"taker_pays":{"currency":"XRP"},"taker_gets":{"currency":"USD","issuer":"rhub8VRN55s94qWKDv6jmDy1pUykJzF3wq"}}]})) - #await ws.send(json.dumps({"command":"subscribe","streams":["ledger"]})) + await ws.send(json.dumps({"command":"server_info"})); + print(json.loads(await ws.recv())) + await ws.send(json.dumps({"command":"server_info"})); + print(json.loads(await ws.recv())) + await ws.send(json.dumps({"command":"server_info"})); + await ws.send(json.dumps({"command":"server_info"})); + await ws.send(json.dumps({"command":"server_info"})); + print(json.loads(await ws.recv())) + print(json.loads(await ws.recv())) + print(json.loads(await ws.recv())) + await ws.send(json.dumps({"command":"subscribe","streams":["ledger","transactions","transactions_proposed"],"books":[{"snapshot":True,"taker_pays":{"currency":"XRP"},"taker_gets":{"currency":"USD","issuer":"rhub8VRN55s94qWKDv6jmDy1pUykJzF3wq"}}]})) + #await ws.send(json.dumps({"command":"subscribe","streams":["manifests"]})) while True: res = json.loads(await ws.recv()) print(json.dumps(res,indent=4,sort_keys=True)) except websockets.exceptions.connectionclosederror as e: print(e) +async def verifySubscribe(ip,clioPort,ripdPort): + clioAddress = 'ws://' + str(ip) + ':' + str(clioPort) + ripdAddress = 'ws://' + str(ip) + ':' + str(ripdPort) + ripdMsgs = [] + clioMsgs = [] + try: + async with websockets.connect(clioAddress) as ws1: + async with websockets.connect(ripdAddress) as ws2: + await ws1.send(json.dumps({"command":"server_info"})) + res = json.loads(await ws1.recv()) + print(res) + start = int(res["result"]["info"]["complete_ledgers"].split("-")[1]) + end = start + 10 + + streams = ["ledger","transactions"] + await ws1.send(json.dumps({"command":"subscribe","streams":streams})), + await ws2.send(json.dumps({"command":"subscribe","streams":streams})) + + res1 = json.loads(await ws1.recv())["result"] + res2 = json.loads(await ws2.recv())["result"] + assert("validated_ledgers" in res1 and "validated_ledgers" in res2) + res1["validated_ledgers"] = "" + res2["validated_ledgers"] = "" + print(res1) + print(res2) + assert(json.dumps(res1,sort_keys=True) == json.dumps(res2,sort_keys=True)) + while True: + print("getting second message") + res1 = json.loads(await ws1.recv()) + res2 = json.loads(await ws2.recv()) + print("got second message") + if "type" in res1 and res1["type"] == "ledgerClosed": + if int(res1["ledger_index"]) >= end: + print("breaking") + break + assert("validated_ledgers" in res1 and "validated_ledgers" in res2) + res1["validated_ledgers"] = "" + res2["validated_ledgers"] = "" + print(res1) + print(res2) + ripdMsgs.append(res1) + clioMsgs.append(res2) + #assert(json.dumps(res1,sort_keys=True) == json.dumps(res2,sort_keys=True)) + assert(ripdMsgs == clioMsgs) + except websockets.exceptions.connectionclosederror as e: + print(e) + + + + parser = argparse.ArgumentParser(description='test script for xrpl-reporting') -parser.add_argument('action', choices=["account_info", "tx", "txs","account_tx", "account_tx_full","ledger_data", "ledger_data_full", "book_offers","ledger","ledger_range","ledger_entry", "ledgers", "ledger_entries","account_txs","account_infos","account_txs_full","book_offerses","ledger_diff","perf","fee","server_info", "gaps","subscribe"]) +parser.add_argument('action', choices=["account_info", "tx", "txs","account_tx", "account_tx_full","ledger_data", "ledger_data_full", "book_offers","ledger","ledger_range","ledger_entry", "ledgers", "ledger_entries","account_txs","account_infos","account_txs_full","book_offerses","ledger_diff","perf","fee","server_info", "gaps","subscribe","verify_subscribe"]) parser.add_argument('--ip', default='127.0.0.1') parser.add_argument('--port', default='8080') @@ -836,8 +897,8 @@ parser.add_argument('--taker_pays_issuer',default='rvYAfWj5gh67oV6fW32ZzP3Aw4Eub parser.add_argument('--taker_pays_currency',default='USD') parser.add_argument('--taker_gets_issuer') parser.add_argument('--taker_gets_currency',default='XRP') -parser.add_argument('--p2pIp', default='s2.ripple.com') -parser.add_argument('--p2pPort', default='51233') +parser.add_argument('--p2pIp', default='127.0.0.1') +parser.add_argument('--p2pPort', default='6006') parser.add_argument('--verify',default=False) parser.add_argument('--binary',default=True) parser.add_argument('--forward',default=False) @@ -891,6 +952,8 @@ def run(args): print(missing) elif args.action == "subscribe": asyncio.get_event_loop().run_until_complete(subscribe(args.ip,args.port)) + elif args.action == "verify_subscribe": + asyncio.get_event_loop().run_until_complete(verifySubscribe(args.ip,args.port,args.p2pPort)) elif args.action == "account_info": res1 = asyncio.get_event_loop().run_until_complete( account_info(args.ip, args.port, args.account, args.ledger, args.binary))