diff --git a/docs/configure-clio.md b/docs/configure-clio.md index bd6d3fe8..ecd596da 100644 --- a/docs/configure-clio.md +++ b/docs/configure-clio.md @@ -27,7 +27,7 @@ If you're running Clio and `rippled` on separate machines, in addition to uncomm 2. Open a public, unencrypted WebSocket port on your `rippled` server. -3. In the `rippled` config, change the IP specified for `secure_gateway`, under the `port_grpc` section, to the IP of your Clio server. This entry can take the form of a comma-separated list if you are running multiple Clio nodes. +3. In the `rippled` config, change the IP specified for `secure_gateway`, under the `port_grpc` and websocket server sections, to the IP of your Clio server. This entry can take the form of a comma-separated list if you are running multiple Clio nodes. ## Ledger sequence diff --git a/docs/examples/config/example-config.json b/docs/examples/config/example-config.json index 711c060e..e62cf1be 100644 --- a/docs/examples/config/example-config.json +++ b/docs/examples/config/example-config.json @@ -64,7 +64,7 @@ "admin_password": "xrp", // If local_admin is true, Clio will consider requests come from 127.0.0.1 as admin requests // It's true by default unless admin_password is set,'local_admin' : true and 'admin_password' can not be set at the same time - "local_amdin": false + "local_admin": false }, // Time in seconds for graceful shutdown. Defaults to 10 seconds. Not fully implemented yet. "graceful_period": 10.0, diff --git a/src/etl/ETLState.hpp b/src/etl/ETLState.hpp index 1fa5a4ef..91004e40 100644 --- a/src/etl/ETLState.hpp +++ b/src/etl/ETLState.hpp @@ -47,7 +47,7 @@ struct ETLState { fetchETLStateFromSource(Forward& source) noexcept { auto const serverInfoRippled = data::synchronous([&source](auto yield) { - return source.forwardToRippled({{"command", "server_info"}}, std::nullopt, yield); + return source.forwardToRippled({{"command", "server_info"}}, std::nullopt, {}, yield); }); if (serverInfoRippled) diff --git a/src/etl/LoadBalancer.cpp b/src/etl/LoadBalancer.cpp index 7e10bb4a..83c714c6 100644 --- a/src/etl/LoadBalancer.cpp +++ b/src/etl/LoadBalancer.cpp @@ -115,7 +115,10 @@ LoadBalancer::LoadBalancer( chooseForwardingSource(); }, [this]() { chooseForwardingSource(); }, - [this]() { forwardingCache_->invalidate(); } + [this]() { + if (forwardingCache_.has_value()) + forwardingCache_->invalidate(); + } ); // checking etl node validity @@ -213,6 +216,7 @@ std::optional LoadBalancer::forwardToRippled( boost::json::object const& request, std::optional const& clientIp, + bool isAdmin, boost::asio::yield_context yield ) { @@ -227,9 +231,11 @@ LoadBalancer::forwardToRippled( auto numAttempts = 0u; + auto xUserValue = isAdmin ? ADMIN_FORWARDING_X_USER_VALUE : USER_FORWARDING_X_USER_VALUE; + std::optional response; while (numAttempts < sources_.size()) { - if (auto res = sources_[sourceIdx]->forwardToRippled(request, clientIp, yield)) { + if (auto res = sources_[sourceIdx]->forwardToRippled(request, clientIp, xUserValue, yield)) { response = std::move(res); break; } diff --git a/src/etl/LoadBalancer.hpp b/src/etl/LoadBalancer.hpp index 33916c19..084f55a9 100644 --- a/src/etl/LoadBalancer.hpp +++ b/src/etl/LoadBalancer.hpp @@ -20,8 +20,8 @@ #pragma once #include "data/BackendInterface.hpp" -#include "etl/ETLHelpers.hpp" #include "etl/ETLState.hpp" +#include "etl/NetworkValidatedLedgersInterface.hpp" #include "etl/Source.hpp" #include "etl/impl/ForwardingCache.hpp" #include "feed/SubscriptionManagerInterface.hpp" @@ -44,7 +44,7 @@ #include #include #include -#include +#include #include namespace etl { @@ -68,6 +68,8 @@ private: util::Logger log_{"ETL"}; // Forwarding cache must be destroyed after sources because sources have a callback to invalidate cache std::optional forwardingCache_; + std::optional forwardingXUserValue_; + std::vector sources_; std::optional etlState_; std::uint32_t downloadRanges_ = @@ -75,6 +77,16 @@ private: std::atomic_bool hasForwardingSource_{false}; public: + /** + * @brief Value for the X-User header when forwarding admin requests + */ + static constexpr std::string_view ADMIN_FORWARDING_X_USER_VALUE = "clio_admin"; + + /** + * @brief Value for the X-User header when forwarding user requests + */ + static constexpr std::string_view USER_FORWARDING_X_USER_VALUE = "clio_user"; + /** * @brief Create an instance of the load balancer. * @@ -167,6 +179,7 @@ public: * * @param request JSON-RPC request to forward * @param clientIp The IP address of the peer, if known + * @param isAdmin Whether the request is from an admin * @param yield The coroutine context * @return Response received from rippled node as JSON object on success; nullopt on failure */ @@ -174,6 +187,7 @@ public: forwardToRippled( boost::json::object const& request, std::optional const& clientIp, + bool isAdmin, boost::asio::yield_context yield ); diff --git a/src/etl/Source.hpp b/src/etl/Source.hpp index 59aa07d7..76524558 100644 --- a/src/etl/Source.hpp +++ b/src/etl/Source.hpp @@ -37,6 +37,7 @@ #include #include #include +#include #include #include @@ -127,6 +128,7 @@ public: * * @param request The request to forward * @param forwardToRippledClientIp IP of the client forwarding this request if known + * @param xUserValue Value of the X-User header * @param yield The coroutine context * @return Response wrapped in an optional on success; nullopt otherwise */ @@ -134,6 +136,7 @@ public: forwardToRippled( boost::json::object const& request, std::optional const& forwardToRippledClientIp, + std::string_view xUserValue, boost::asio::yield_context yield ) const = 0; }; diff --git a/src/etl/impl/ForwardingSource.cpp b/src/etl/impl/ForwardingSource.cpp index 5afb2f34..4101be9f 100644 --- a/src/etl/impl/ForwardingSource.cpp +++ b/src/etl/impl/ForwardingSource.cpp @@ -34,6 +34,7 @@ #include #include #include +#include #include namespace etl::impl { @@ -55,6 +56,7 @@ std::optional ForwardingSource::forwardToRippled( boost::json::object const& request, std::optional const& forwardToRippledClientIp, + std::string_view xUserValue, boost::asio::yield_context yield ) const { @@ -64,6 +66,9 @@ ForwardingSource::forwardToRippled( {boost::beast::http::field::forwarded, fmt::format("for={}", *forwardToRippledClientIp)} ); } + + connectionBuilder.addHeader({"X-User", std::string{xUserValue}}); + auto expectedConnection = connectionBuilder.connect(yield); if (not expectedConnection) { return std::nullopt; diff --git a/src/etl/impl/ForwardingSource.hpp b/src/etl/impl/ForwardingSource.hpp index b7752773..373786de 100644 --- a/src/etl/impl/ForwardingSource.hpp +++ b/src/etl/impl/ForwardingSource.hpp @@ -28,6 +28,7 @@ #include #include #include +#include namespace etl::impl { @@ -49,6 +50,7 @@ public: * * @param request The request to forward * @param forwardToRippledClientIp IP of the client forwarding this request if known + * @param xUserValue Optional value for X-User header * @param yield The coroutine context * @return Response wrapped in an optional on success; nullopt otherwise */ @@ -56,6 +58,7 @@ public: forwardToRippled( boost::json::object const& request, std::optional const& forwardToRippledClientIp, + std::string_view xUserValue, boost::asio::yield_context yield ) const; }; diff --git a/src/etl/impl/SourceImpl.hpp b/src/etl/impl/SourceImpl.hpp index 3dac2ed8..a24ddff5 100644 --- a/src/etl/impl/SourceImpl.hpp +++ b/src/etl/impl/SourceImpl.hpp @@ -34,6 +34,7 @@ #include #include #include +#include #include #include namespace etl::impl { @@ -202,6 +203,7 @@ public: * * @param request The request to forward * @param forwardToRippledClientIp IP of the client forwarding this request if known + * @param xUserValue Optional value of the X-User header * @param yield The coroutine context * @return Response wrapped in an optional on success; nullopt otherwise */ @@ -209,10 +211,11 @@ public: forwardToRippled( boost::json::object const& request, std::optional const& forwardToRippledClientIp, + std::string_view xUserValue, boost::asio::yield_context yield ) const final { - return forwardingSource_.forwardToRippled(request, forwardToRippledClientIp, yield); + return forwardingSource_.forwardToRippled(request, forwardToRippledClientIp, xUserValue, yield); } }; diff --git a/src/rpc/common/impl/ForwardingProxy.hpp b/src/rpc/common/impl/ForwardingProxy.hpp index ab5c6b33..757f8493 100644 --- a/src/rpc/common/impl/ForwardingProxy.hpp +++ b/src/rpc/common/impl/ForwardingProxy.hpp @@ -93,7 +93,7 @@ public: auto toForward = ctx.params; toForward["command"] = ctx.method; - auto res = balancer_->forwardToRippled(toForward, ctx.clientIp, ctx.yield); + auto res = balancer_->forwardToRippled(toForward, ctx.clientIp, ctx.isAdmin, ctx.yield); if (not res) { notifyFailedToForward(ctx.method); return Result{Status{RippledError::rpcFAILED_TO_FORWARD}}; diff --git a/src/rpc/handlers/ServerInfo.hpp b/src/rpc/handlers/ServerInfo.hpp index b43a75a3..18a076c4 100644 --- a/src/rpc/handlers/ServerInfo.hpp +++ b/src/rpc/handlers/ServerInfo.hpp @@ -222,7 +222,7 @@ public: } auto const serverInfoRippled = - balancer_->forwardToRippled({{"command", "server_info"}}, ctx.clientIp, ctx.yield); + balancer_->forwardToRippled({{"command", "server_info"}}, ctx.clientIp, ctx.isAdmin, ctx.yield); if (serverInfoRippled && !serverInfoRippled->contains(JS(error))) { if (serverInfoRippled->contains(JS(result)) && diff --git a/tests/common/util/MockLoadBalancer.hpp b/tests/common/util/MockLoadBalancer.hpp index 73ce7c1d..60749bb8 100644 --- a/tests/common/util/MockLoadBalancer.hpp +++ b/tests/common/util/MockLoadBalancer.hpp @@ -40,7 +40,7 @@ struct MockLoadBalancer { MOCK_METHOD( std::optional, forwardToRippled, - (boost::json::object const&, std::optional const&, boost::asio::yield_context), + (boost::json::object const&, std::optional const&, bool, boost::asio::yield_context), (const) ); }; diff --git a/tests/common/util/MockSource.hpp b/tests/common/util/MockSource.hpp index 54121149..eda22951 100644 --- a/tests/common/util/MockSource.hpp +++ b/tests/common/util/MockSource.hpp @@ -19,7 +19,7 @@ #pragma once #include "data/BackendInterface.hpp" -#include "etl/ETLHelpers.hpp" +#include "etl/NetworkValidatedLedgersInterface.hpp" #include "etl/Source.hpp" #include "feed/SubscriptionManagerInterface.hpp" #include "util/config/Config.hpp" @@ -39,7 +39,7 @@ #include #include #include -#include +#include #include #include @@ -60,7 +60,7 @@ struct MockSource : etl::SourceBase { MOCK_METHOD( std::optional, forwardToRippled, - (boost::json::object const&, std::optional const&, boost::asio::yield_context), + (boost::json::object const&, std::optional const&, std::string_view, boost::asio::yield_context), (const, override) ); }; @@ -129,10 +129,11 @@ public: forwardToRippled( boost::json::object const& request, std::optional const& forwardToRippledClientIp, + std::string_view xUserValue, boost::asio::yield_context yield ) const override { - return mock_->forwardToRippled(request, forwardToRippledClientIp, yield); + return mock_->forwardToRippled(request, forwardToRippledClientIp, xUserValue, yield); } }; diff --git a/tests/common/util/TestWsServer.cpp b/tests/common/util/TestWsServer.cpp index dfb60dd0..4fa33261 100644 --- a/tests/common/util/TestWsServer.cpp +++ b/tests/common/util/TestWsServer.cpp @@ -32,20 +32,30 @@ #include #include #include +#include +#include +#include #include #include #include #include +#include #include +#include #include #include #include +#include namespace asio = boost::asio; namespace websocket = boost::beast::websocket; -TestWsConnection::TestWsConnection(websocket::stream wsStream) : ws_(std::move(wsStream)) +TestWsConnection::TestWsConnection( + websocket::stream wsStream, + std::vector headers +) + : ws_(std::move(wsStream)), headers_(std::move(headers)) { } @@ -83,6 +93,12 @@ TestWsConnection::close(boost::asio::yield_context yield) return std::nullopt; } +std::vector const& +TestWsConnection::headers() const +{ + return headers_; +} + TestWsServer::TestWsServer(asio::io_context& context, std::string const& host, int port) : acceptor_(context) { auto endpoint = asio::ip::tcp::endpoint(boost::asio::ip::make_address(host), port); @@ -102,13 +118,28 @@ TestWsServer::acceptConnection(asio::yield_context yield) if (errorCode) return std::unexpected{util::requests::RequestError{"Accept error", errorCode}}; + boost::beast::flat_buffer buffer; + boost::beast::http::request request; + boost::beast::http::async_read(socket, buffer, request, yield[errorCode]); + if (errorCode) + return std::unexpected{util::requests::RequestError{"Read error", errorCode}}; + std::vector headers; + std::transform(request.begin(), request.end(), std::back_inserter(headers), [](auto const& header) { + if (header.name() == boost::beast::http::field::unknown) + return util::requests::HttpHeader{header.name_string(), header.value()}; + + return util::requests::HttpHeader{header.name(), header.value()}; + }); + if (not boost::beast::websocket::is_upgrade(request)) + return std::unexpected{util::requests::RequestError{"Not a websocket request"}}; + boost::beast::websocket::stream ws(std::move(socket)); ws.set_option(websocket::stream_base::timeout::suggested(boost::beast::role_type::server)); - ws.async_accept(yield[errorCode]); + ws.async_accept(request, yield[errorCode]); if (errorCode) return std::unexpected{util::requests::RequestError{"Handshake error", errorCode}}; - return TestWsConnection(std::move(ws)); + return TestWsConnection(std::move(ws), std::move(headers)); } void diff --git a/tests/common/util/TestWsServer.hpp b/tests/common/util/TestWsServer.hpp index e258f529..15760998 100644 --- a/tests/common/util/TestWsServer.hpp +++ b/tests/common/util/TestWsServer.hpp @@ -31,15 +31,20 @@ #include #include #include +#include class TestWsConnection { boost::beast::websocket::stream ws_; + std::vector headers_; public: using SendCallback = std::function; using ReceiveCallback = std::function; - TestWsConnection(boost::beast::websocket::stream wsStream); + TestWsConnection( + boost::beast::websocket::stream wsStream, + std::vector headers + ); // returns error message if error occurs std::optional @@ -51,6 +56,9 @@ public: std::optional close(boost::asio::yield_context yield); + + std::vector const& + headers() const; }; class TestWsServer { diff --git a/tests/unit/etl/ForwardingSourceTests.cpp b/tests/unit/etl/ForwardingSourceTests.cpp index 4cd22a35..07523e60 100644 --- a/tests/unit/etl/ForwardingSourceTests.cpp +++ b/tests/unit/etl/ForwardingSourceTests.cpp @@ -27,6 +27,7 @@ #include #include +#include #include #include #include @@ -42,7 +43,7 @@ struct ForwardingSourceTests : SyncAsioContextTest { TEST_F(ForwardingSourceTests, ConnectionFailed) { runSpawn([&](boost::asio::yield_context yield) { - auto result = forwardingSource.forwardToRippled({}, {}, yield); + auto result = forwardingSource.forwardToRippled({}, {}, {}, yield); EXPECT_FALSE(result); }); } @@ -59,11 +60,34 @@ struct ForwardingSourceOperationsTests : ForwardingSourceTests { [&]() { ASSERT_FALSE(failedConnection); }(); auto connection = server_.acceptConnection(yield); - [&]() { ASSERT_TRUE(connection); }(); + [&]() { ASSERT_TRUE(connection) << connection.error().message(); }(); return std::move(connection).value(); } }; +TEST_F(ForwardingSourceOperationsTests, XUserHeader) +{ + std::string const xUserValue = "some_user"; + boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) { + auto connection = serverConnection(yield); + auto headers = connection.headers(); + ASSERT_FALSE(headers.empty()); + auto it = std::find_if(headers.begin(), headers.end(), [](auto const& header) { + return std::holds_alternative(header.name) && std::get(header.name) == "X-User"; + }); + ASSERT_FALSE(it == headers.end()); + EXPECT_EQ(std::get(it->name), "X-User"); + EXPECT_EQ(it->value, xUserValue); + connection.close(yield); + }); + + runSpawn([&](boost::asio::yield_context yield) { + auto result = + forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, xUserValue, yield); + EXPECT_FALSE(result); + }); +} + TEST_F(ForwardingSourceOperationsTests, ReadFailed) { boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) { @@ -72,7 +96,7 @@ TEST_F(ForwardingSourceOperationsTests, ReadFailed) }); runSpawn([&](boost::asio::yield_context yield) { - auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, yield); + auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, {}, yield); EXPECT_FALSE(result); }); } @@ -93,7 +117,7 @@ TEST_F(ForwardingSourceOperationsTests, ParseFailed) }); runSpawn([&](boost::asio::yield_context yield) { - auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, yield); + auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, {}, yield); EXPECT_FALSE(result); }); } @@ -115,7 +139,7 @@ TEST_F(ForwardingSourceOperationsTests, GotNotAnObject) }); runSpawn([&](boost::asio::yield_context yield) { - auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, yield); + auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, {}, yield); EXPECT_FALSE(result); }); } @@ -134,7 +158,7 @@ TEST_F(ForwardingSourceOperationsTests, Success) }); runSpawn([&](boost::asio::yield_context yield) { - auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), "some_ip", yield); + auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), "some_ip", {}, yield); [&]() { ASSERT_TRUE(result); }(); auto expectedReply = reply_; expectedReply["forwarded"] = true; diff --git a/tests/unit/etl/LoadBalancerTests.cpp b/tests/unit/etl/LoadBalancerTests.cpp index 2bf9fd14..c662cab4 100644 --- a/tests/unit/etl/LoadBalancerTests.cpp +++ b/tests/unit/etl/LoadBalancerTests.cpp @@ -485,37 +485,66 @@ struct LoadBalancerForwardToRippledTests : LoadBalancerConstructorTests, SyncAsi TEST_F(LoadBalancerForwardToRippledTests, forward) { auto loadBalancer = makeLoadBalancer(); - EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled(request_, clientIP_, testing::_)) + EXPECT_CALL( + sourceFactory_.sourceAt(0), + forwardToRippled(request_, clientIP_, LoadBalancer::ADMIN_FORWARDING_X_USER_VALUE, testing::_) + ) .WillOnce(Return(response_)); runSpawn([&](boost::asio::yield_context yield) { - EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, yield), response_); + EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, true, yield), response_); + }); +} + +TEST_F(LoadBalancerForwardToRippledTests, forwardWithXUserHeader) +{ + auto loadBalancer = makeLoadBalancer(); + EXPECT_CALL( + sourceFactory_.sourceAt(0), + forwardToRippled(request_, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_) + ) + .WillOnce(Return(response_)); + + runSpawn([&](boost::asio::yield_context yield) { + EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, false, yield), response_); }); } TEST_F(LoadBalancerForwardToRippledTests, source0Fails) { auto loadBalancer = makeLoadBalancer(); - EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled(request_, clientIP_, testing::_)) + EXPECT_CALL( + sourceFactory_.sourceAt(0), + forwardToRippled(request_, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_) + ) .WillOnce(Return(std::nullopt)); - EXPECT_CALL(sourceFactory_.sourceAt(1), forwardToRippled(request_, clientIP_, testing::_)) + EXPECT_CALL( + sourceFactory_.sourceAt(1), + forwardToRippled(request_, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_) + ) .WillOnce(Return(response_)); runSpawn([&](boost::asio::yield_context yield) { - EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, yield), response_); + EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, false, yield), response_); }); } TEST_F(LoadBalancerForwardToRippledTests, bothSourcesFail) { auto loadBalancer = makeLoadBalancer(); - EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled(request_, clientIP_, testing::_)) + EXPECT_CALL( + sourceFactory_.sourceAt(0), + forwardToRippled(request_, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_) + ) .WillOnce(Return(std::nullopt)); - EXPECT_CALL(sourceFactory_.sourceAt(1), forwardToRippled(request_, clientIP_, testing::_)) + EXPECT_CALL( + sourceFactory_.sourceAt(1), + forwardToRippled(request_, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_) + ) .WillOnce(Return(std::nullopt)); runSpawn([&](boost::asio::yield_context yield) { - EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, yield), std::nullopt); + EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, false, yield), std::nullopt); }); } @@ -526,15 +555,24 @@ TEST_F(LoadBalancerForwardToRippledTests, forwardingCacheEnabled) auto const request = boost::json::object{{"command", "server_info"}}; - EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled(request, clientIP_, testing::_)) + EXPECT_CALL( + sourceFactory_.sourceAt(0), + forwardToRippled(request, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_) + ) .WillOnce(Return(response_)); runSpawn([&](boost::asio::yield_context yield) { - EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, yield), response_); - EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, yield), response_); + EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, false, yield), response_); + EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, false, yield), response_); }); } +TEST_F(LoadBalancerForwardToRippledTests, forwardingCacheDisabledOnLedgerClosedHookCalled) +{ + auto loadBalancer = makeLoadBalancer(); + EXPECT_NO_THROW(sourceFactory_.callbacksAt(0).onLedgerClosed()); +} + TEST_F(LoadBalancerForwardToRippledTests, onLedgerClosedHookInvalidatesCache) { configJson_.as_object()["forwarding_cache_timeout"] = 10.; @@ -542,16 +580,22 @@ TEST_F(LoadBalancerForwardToRippledTests, onLedgerClosedHookInvalidatesCache) auto const request = boost::json::object{{"command", "server_info"}}; - EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled(request, clientIP_, testing::_)) + EXPECT_CALL( + sourceFactory_.sourceAt(0), + forwardToRippled(request, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_) + ) .WillOnce(Return(response_)); - EXPECT_CALL(sourceFactory_.sourceAt(1), forwardToRippled(request, clientIP_, testing::_)) + EXPECT_CALL( + sourceFactory_.sourceAt(1), + forwardToRippled(request, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_) + ) .WillOnce(Return(boost::json::object{})); runSpawn([&](boost::asio::yield_context yield) { - EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, yield), response_); - EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, yield), response_); + EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, false, yield), response_); + EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, false, yield), response_); sourceFactory_.callbacksAt(0).onLedgerClosed(); - EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, yield), boost::json::object{}); + EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, false, yield), boost::json::object{}); }); } diff --git a/tests/unit/etl/SourceImplTests.cpp b/tests/unit/etl/SourceImplTests.cpp index a9c78ecd..f9535dd4 100644 --- a/tests/unit/etl/SourceImplTests.cpp +++ b/tests/unit/etl/SourceImplTests.cpp @@ -33,6 +33,7 @@ #include #include #include +#include #include #include @@ -67,7 +68,7 @@ struct ForwardingSourceMock { MOCK_METHOD( ForwardToRippledReturnType, forwardToRippled, - (boost::json::object const&, ClientIpOpt const&, boost::asio::yield_context), + (boost::json::object const&, ClientIpOpt const&, std::string_view, boost::asio::yield_context), (const) ); }; @@ -175,12 +176,14 @@ TEST_F(SourceImplTest, forwardToRippled) { boost::json::object const request = {{"some_key", "some_value"}}; std::optional const clientIp = "some_client_ip"; + std::string_view xUserValue = "some_user"; - EXPECT_CALL(forwardingSourceMock_, forwardToRippled(request, clientIp, testing::_)).WillOnce(Return(request)); + EXPECT_CALL(forwardingSourceMock_, forwardToRippled(request, clientIp, xUserValue, testing::_)) + .WillOnce(Return(request)); boost::asio::io_context ioContext; boost::asio::spawn(ioContext, [&](boost::asio::yield_context yield) { - auto const response = source_.forwardToRippled(request, clientIp, yield); + auto const response = source_.forwardToRippled(request, clientIp, xUserValue, yield); EXPECT_EQ(response, request); }); ioContext.run(); diff --git a/tests/unit/rpc/ForwardingProxyTests.cpp b/tests/unit/rpc/ForwardingProxyTests.cpp index 134bca86..e3952010 100644 --- a/tests/unit/rpc/ForwardingProxyTests.cpp +++ b/tests/unit/rpc/ForwardingProxyTests.cpp @@ -333,15 +333,14 @@ TEST_F(RPCForwardingProxyTest, ForwardCallsBalancerWithCorrectParams) auto const params = json::parse(R"({"test": true})"); auto const forwarded = json::parse(R"({"test": true, "command": "submit"})"); - ON_CALL(*rawBalancerPtr, forwardToRippled).WillByDefault(Return(std::make_optional())); - EXPECT_CALL(*rawBalancerPtr, forwardToRippled(forwarded.as_object(), std::make_optional(CLIENT_IP), _)) - .Times(1); + EXPECT_CALL( + *rawBalancerPtr, forwardToRippled(forwarded.as_object(), std::make_optional(CLIENT_IP), true, _) + ) + .WillOnce(Return(std::make_optional())); - ON_CALL(*rawHandlerProviderPtr, contains).WillByDefault(Return(true)); - EXPECT_CALL(*rawHandlerProviderPtr, contains(method)).Times(1); + EXPECT_CALL(*rawHandlerProviderPtr, contains(method)).WillOnce(Return(true)); - ON_CALL(counters, rpcForwarded).WillByDefault(Return()); - EXPECT_CALL(counters, rpcForwarded(method)).Times(1); + EXPECT_CALL(counters, rpcForwarded(method)); runSpawn([&](auto yield) { auto const range = backend->fetchLedgerRange(); @@ -364,15 +363,14 @@ TEST_F(RPCForwardingProxyTest, ForwardingFailYieldsErrorStatus) auto const params = json::parse(R"({"test": true})"); auto const forwarded = json::parse(R"({"test": true, "command": "submit"})"); - ON_CALL(*rawBalancerPtr, forwardToRippled).WillByDefault(Return(std::nullopt)); - EXPECT_CALL(*rawBalancerPtr, forwardToRippled(forwarded.as_object(), std::make_optional(CLIENT_IP), _)) - .Times(1); + EXPECT_CALL( + *rawBalancerPtr, forwardToRippled(forwarded.as_object(), std::make_optional(CLIENT_IP), true, _) + ) + .WillOnce(Return(std::nullopt)); - ON_CALL(*rawHandlerProviderPtr, contains).WillByDefault(Return(true)); - EXPECT_CALL(*rawHandlerProviderPtr, contains(method)).Times(1); + EXPECT_CALL(*rawHandlerProviderPtr, contains(method)).WillOnce(Return(true)); - ON_CALL(counters, rpcFailedToForward).WillByDefault(Return()); - EXPECT_CALL(counters, rpcFailedToForward(method)).Times(1); + EXPECT_CALL(counters, rpcFailedToForward(method)); runSpawn([&](auto yield) { auto const range = backend->fetchLedgerRange(); diff --git a/tests/unit/rpc/handlers/ServerInfoTests.cpp b/tests/unit/rpc/handlers/ServerInfoTests.cpp index 84002a0d..925980d7 100644 --- a/tests/unit/rpc/handlers/ServerInfoTests.cpp +++ b/tests/unit/rpc/handlers/ServerInfoTests.cpp @@ -191,7 +191,7 @@ TEST_F(RPCServerInfoHandlerTest, DefaultOutputIsPresent) auto const feeBlob = CreateLegacyFeeSettingBlob(1, 2, 3, 4, 0); EXPECT_CALL(*backend, doFetchLedgerObject).WillOnce(Return(feeBlob)); - EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), testing::_)) + EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), false, testing::_)) .WillOnce(Return(std::nullopt)); EXPECT_CALL(*rawCountersPtr, uptime).WillOnce(Return(std::chrono::seconds{1234})); @@ -228,7 +228,7 @@ TEST_F(RPCServerInfoHandlerTest, AmendmentBlockedIsPresentIfSet) auto const feeBlob = CreateLegacyFeeSettingBlob(1, 2, 3, 4, 0); EXPECT_CALL(*backend, doFetchLedgerObject).WillOnce(Return(feeBlob)); - EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), testing::_)) + EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), false, testing::_)) .WillOnce(Return(std::nullopt)); EXPECT_CALL(*rawCountersPtr, uptime).WillOnce(Return(std::chrono::seconds{1234})); @@ -263,7 +263,7 @@ TEST_F(RPCServerInfoHandlerTest, CorruptionDetectedIsPresentIfSet) auto const feeBlob = CreateLegacyFeeSettingBlob(1, 2, 3, 4, 0); EXPECT_CALL(*backend, doFetchLedgerObject).WillOnce(Return(feeBlob)); - EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), testing::_)) + EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), false, testing::_)) .WillOnce(Return(std::nullopt)); EXPECT_CALL(*rawCountersPtr, uptime).WillOnce(Return(std::chrono::seconds{1234})); @@ -297,7 +297,7 @@ TEST_F(RPCServerInfoHandlerTest, CacheReportsEnabledFlagCorrectly) auto const feeBlob = CreateLegacyFeeSettingBlob(1, 2, 3, 4, 0); EXPECT_CALL(*backend, doFetchLedgerObject).Times(2).WillRepeatedly(Return(feeBlob)); - EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), testing::_)) + EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), false, testing::_)) .Times(2) .WillRepeatedly(Return(std::nullopt));