Add option to set X-User header value for forwarded requests (#1425)

Fixes #1422.
This commit is contained in:
Sergey Kuznetsov
2024-06-11 17:59:10 +01:00
committed by GitHub
parent 9d3b4f0313
commit 56ab943be5
20 changed files with 205 additions and 62 deletions

View File

@@ -40,7 +40,7 @@ struct MockLoadBalancer {
MOCK_METHOD(
std::optional<boost::json::object>,
forwardToRippled,
(boost::json::object const&, std::optional<std::string> const&, boost::asio::yield_context),
(boost::json::object const&, std::optional<std::string> const&, bool, boost::asio::yield_context),
(const)
);
};

View File

@@ -19,7 +19,7 @@
#pragma once
#include "data/BackendInterface.hpp"
#include "etl/ETLHelpers.hpp"
#include "etl/NetworkValidatedLedgersInterface.hpp"
#include "etl/Source.hpp"
#include "feed/SubscriptionManagerInterface.hpp"
#include "util/config/Config.hpp"
@@ -39,7 +39,7 @@
#include <memory>
#include <optional>
#include <string>
#include <unordered_map>
#include <string_view>
#include <utility>
#include <vector>
@@ -60,7 +60,7 @@ struct MockSource : etl::SourceBase {
MOCK_METHOD(
std::optional<boost::json::object>,
forwardToRippled,
(boost::json::object const&, std::optional<std::string> const&, boost::asio::yield_context),
(boost::json::object const&, std::optional<std::string> const&, std::string_view, boost::asio::yield_context),
(const, override)
);
};
@@ -129,10 +129,11 @@ public:
forwardToRippled(
boost::json::object const& request,
std::optional<std::string> const& forwardToRippledClientIp,
std::string_view xUserValue,
boost::asio::yield_context yield
) const override
{
return mock_->forwardToRippled(request, forwardToRippledClientIp, yield);
return mock_->forwardToRippled(request, forwardToRippledClientIp, xUserValue, yield);
}
};

View File

@@ -32,20 +32,30 @@
#include <boost/beast/core/flat_buffer.hpp>
#include <boost/beast/core/role.hpp>
#include <boost/beast/core/tcp_stream.hpp>
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/string_body.hpp>
#include <boost/beast/websocket/error.hpp>
#include <boost/beast/websocket/rfc6455.hpp>
#include <boost/beast/websocket/stream_base.hpp>
#include <gtest/gtest.h>
#include <algorithm>
#include <expected>
#include <iterator>
#include <optional>
#include <string>
#include <utility>
#include <vector>
namespace asio = boost::asio;
namespace websocket = boost::beast::websocket;
TestWsConnection::TestWsConnection(websocket::stream<boost::beast::tcp_stream> wsStream) : ws_(std::move(wsStream))
TestWsConnection::TestWsConnection(
websocket::stream<boost::beast::tcp_stream> wsStream,
std::vector<util::requests::HttpHeader> headers
)
: ws_(std::move(wsStream)), headers_(std::move(headers))
{
}
@@ -83,6 +93,12 @@ TestWsConnection::close(boost::asio::yield_context yield)
return std::nullopt;
}
std::vector<util::requests::HttpHeader> const&
TestWsConnection::headers() const
{
return headers_;
}
TestWsServer::TestWsServer(asio::io_context& context, std::string const& host, int port) : acceptor_(context)
{
auto endpoint = asio::ip::tcp::endpoint(boost::asio::ip::make_address(host), port);
@@ -102,13 +118,28 @@ TestWsServer::acceptConnection(asio::yield_context yield)
if (errorCode)
return std::unexpected{util::requests::RequestError{"Accept error", errorCode}};
boost::beast::flat_buffer buffer;
boost::beast::http::request<boost::beast::http::string_body> request;
boost::beast::http::async_read(socket, buffer, request, yield[errorCode]);
if (errorCode)
return std::unexpected{util::requests::RequestError{"Read error", errorCode}};
std::vector<util::requests::HttpHeader> headers;
std::transform(request.begin(), request.end(), std::back_inserter(headers), [](auto const& header) {
if (header.name() == boost::beast::http::field::unknown)
return util::requests::HttpHeader{header.name_string(), header.value()};
return util::requests::HttpHeader{header.name(), header.value()};
});
if (not boost::beast::websocket::is_upgrade(request))
return std::unexpected{util::requests::RequestError{"Not a websocket request"}};
boost::beast::websocket::stream<boost::beast::tcp_stream> ws(std::move(socket));
ws.set_option(websocket::stream_base::timeout::suggested(boost::beast::role_type::server));
ws.async_accept(yield[errorCode]);
ws.async_accept(request, yield[errorCode]);
if (errorCode)
return std::unexpected{util::requests::RequestError{"Handshake error", errorCode}};
return TestWsConnection(std::move(ws));
return TestWsConnection(std::move(ws), std::move(headers));
}
void

View File

@@ -31,15 +31,20 @@
#include <functional>
#include <optional>
#include <string>
#include <vector>
class TestWsConnection {
boost::beast::websocket::stream<boost::beast::tcp_stream> ws_;
std::vector<util::requests::HttpHeader> headers_;
public:
using SendCallback = std::function<void()>;
using ReceiveCallback = std::function<void(std::string)>;
TestWsConnection(boost::beast::websocket::stream<boost::beast::tcp_stream> wsStream);
TestWsConnection(
boost::beast::websocket::stream<boost::beast::tcp_stream> wsStream,
std::vector<util::requests::HttpHeader> headers
);
// returns error message if error occurs
std::optional<std::string>
@@ -51,6 +56,9 @@ public:
std::optional<std::string>
close(boost::asio::yield_context yield);
std::vector<util::requests::HttpHeader> const&
headers() const;
};
class TestWsServer {

View File

@@ -27,6 +27,7 @@
#include <boost/json/serialize.hpp>
#include <gtest/gtest.h>
#include <algorithm>
#include <chrono>
#include <optional>
#include <string>
@@ -42,7 +43,7 @@ struct ForwardingSourceTests : SyncAsioContextTest {
TEST_F(ForwardingSourceTests, ConnectionFailed)
{
runSpawn([&](boost::asio::yield_context yield) {
auto result = forwardingSource.forwardToRippled({}, {}, yield);
auto result = forwardingSource.forwardToRippled({}, {}, {}, yield);
EXPECT_FALSE(result);
});
}
@@ -59,11 +60,34 @@ struct ForwardingSourceOperationsTests : ForwardingSourceTests {
[&]() { ASSERT_FALSE(failedConnection); }();
auto connection = server_.acceptConnection(yield);
[&]() { ASSERT_TRUE(connection); }();
[&]() { ASSERT_TRUE(connection) << connection.error().message(); }();
return std::move(connection).value();
}
};
TEST_F(ForwardingSourceOperationsTests, XUserHeader)
{
std::string const xUserValue = "some_user";
boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) {
auto connection = serverConnection(yield);
auto headers = connection.headers();
ASSERT_FALSE(headers.empty());
auto it = std::find_if(headers.begin(), headers.end(), [](auto const& header) {
return std::holds_alternative<std::string>(header.name) && std::get<std::string>(header.name) == "X-User";
});
ASSERT_FALSE(it == headers.end());
EXPECT_EQ(std::get<std::string>(it->name), "X-User");
EXPECT_EQ(it->value, xUserValue);
connection.close(yield);
});
runSpawn([&](boost::asio::yield_context yield) {
auto result =
forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, xUserValue, yield);
EXPECT_FALSE(result);
});
}
TEST_F(ForwardingSourceOperationsTests, ReadFailed)
{
boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) {
@@ -72,7 +96,7 @@ TEST_F(ForwardingSourceOperationsTests, ReadFailed)
});
runSpawn([&](boost::asio::yield_context yield) {
auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, yield);
auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, {}, yield);
EXPECT_FALSE(result);
});
}
@@ -93,7 +117,7 @@ TEST_F(ForwardingSourceOperationsTests, ParseFailed)
});
runSpawn([&](boost::asio::yield_context yield) {
auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, yield);
auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, {}, yield);
EXPECT_FALSE(result);
});
}
@@ -115,7 +139,7 @@ TEST_F(ForwardingSourceOperationsTests, GotNotAnObject)
});
runSpawn([&](boost::asio::yield_context yield) {
auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, yield);
auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), {}, {}, yield);
EXPECT_FALSE(result);
});
}
@@ -134,7 +158,7 @@ TEST_F(ForwardingSourceOperationsTests, Success)
});
runSpawn([&](boost::asio::yield_context yield) {
auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), "some_ip", yield);
auto result = forwardingSource.forwardToRippled(boost::json::parse(message_).as_object(), "some_ip", {}, yield);
[&]() { ASSERT_TRUE(result); }();
auto expectedReply = reply_;
expectedReply["forwarded"] = true;

View File

@@ -485,37 +485,66 @@ struct LoadBalancerForwardToRippledTests : LoadBalancerConstructorTests, SyncAsi
TEST_F(LoadBalancerForwardToRippledTests, forward)
{
auto loadBalancer = makeLoadBalancer();
EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled(request_, clientIP_, testing::_))
EXPECT_CALL(
sourceFactory_.sourceAt(0),
forwardToRippled(request_, clientIP_, LoadBalancer::ADMIN_FORWARDING_X_USER_VALUE, testing::_)
)
.WillOnce(Return(response_));
runSpawn([&](boost::asio::yield_context yield) {
EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, yield), response_);
EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, true, yield), response_);
});
}
TEST_F(LoadBalancerForwardToRippledTests, forwardWithXUserHeader)
{
auto loadBalancer = makeLoadBalancer();
EXPECT_CALL(
sourceFactory_.sourceAt(0),
forwardToRippled(request_, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_)
)
.WillOnce(Return(response_));
runSpawn([&](boost::asio::yield_context yield) {
EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, false, yield), response_);
});
}
TEST_F(LoadBalancerForwardToRippledTests, source0Fails)
{
auto loadBalancer = makeLoadBalancer();
EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled(request_, clientIP_, testing::_))
EXPECT_CALL(
sourceFactory_.sourceAt(0),
forwardToRippled(request_, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_)
)
.WillOnce(Return(std::nullopt));
EXPECT_CALL(sourceFactory_.sourceAt(1), forwardToRippled(request_, clientIP_, testing::_))
EXPECT_CALL(
sourceFactory_.sourceAt(1),
forwardToRippled(request_, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_)
)
.WillOnce(Return(response_));
runSpawn([&](boost::asio::yield_context yield) {
EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, yield), response_);
EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, false, yield), response_);
});
}
TEST_F(LoadBalancerForwardToRippledTests, bothSourcesFail)
{
auto loadBalancer = makeLoadBalancer();
EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled(request_, clientIP_, testing::_))
EXPECT_CALL(
sourceFactory_.sourceAt(0),
forwardToRippled(request_, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_)
)
.WillOnce(Return(std::nullopt));
EXPECT_CALL(sourceFactory_.sourceAt(1), forwardToRippled(request_, clientIP_, testing::_))
EXPECT_CALL(
sourceFactory_.sourceAt(1),
forwardToRippled(request_, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_)
)
.WillOnce(Return(std::nullopt));
runSpawn([&](boost::asio::yield_context yield) {
EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, yield), std::nullopt);
EXPECT_EQ(loadBalancer->forwardToRippled(request_, clientIP_, false, yield), std::nullopt);
});
}
@@ -526,15 +555,24 @@ TEST_F(LoadBalancerForwardToRippledTests, forwardingCacheEnabled)
auto const request = boost::json::object{{"command", "server_info"}};
EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled(request, clientIP_, testing::_))
EXPECT_CALL(
sourceFactory_.sourceAt(0),
forwardToRippled(request, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_)
)
.WillOnce(Return(response_));
runSpawn([&](boost::asio::yield_context yield) {
EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, yield), response_);
EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, yield), response_);
EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, false, yield), response_);
EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, false, yield), response_);
});
}
TEST_F(LoadBalancerForwardToRippledTests, forwardingCacheDisabledOnLedgerClosedHookCalled)
{
auto loadBalancer = makeLoadBalancer();
EXPECT_NO_THROW(sourceFactory_.callbacksAt(0).onLedgerClosed());
}
TEST_F(LoadBalancerForwardToRippledTests, onLedgerClosedHookInvalidatesCache)
{
configJson_.as_object()["forwarding_cache_timeout"] = 10.;
@@ -542,16 +580,22 @@ TEST_F(LoadBalancerForwardToRippledTests, onLedgerClosedHookInvalidatesCache)
auto const request = boost::json::object{{"command", "server_info"}};
EXPECT_CALL(sourceFactory_.sourceAt(0), forwardToRippled(request, clientIP_, testing::_))
EXPECT_CALL(
sourceFactory_.sourceAt(0),
forwardToRippled(request, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_)
)
.WillOnce(Return(response_));
EXPECT_CALL(sourceFactory_.sourceAt(1), forwardToRippled(request, clientIP_, testing::_))
EXPECT_CALL(
sourceFactory_.sourceAt(1),
forwardToRippled(request, clientIP_, LoadBalancer::USER_FORWARDING_X_USER_VALUE, testing::_)
)
.WillOnce(Return(boost::json::object{}));
runSpawn([&](boost::asio::yield_context yield) {
EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, yield), response_);
EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, yield), response_);
EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, false, yield), response_);
EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, false, yield), response_);
sourceFactory_.callbacksAt(0).onLedgerClosed();
EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, yield), boost::json::object{});
EXPECT_EQ(loadBalancer->forwardToRippled(request, clientIP_, false, yield), boost::json::object{});
});
}

View File

@@ -33,6 +33,7 @@
#include <memory>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>
@@ -67,7 +68,7 @@ struct ForwardingSourceMock {
MOCK_METHOD(
ForwardToRippledReturnType,
forwardToRippled,
(boost::json::object const&, ClientIpOpt const&, boost::asio::yield_context),
(boost::json::object const&, ClientIpOpt const&, std::string_view, boost::asio::yield_context),
(const)
);
};
@@ -175,12 +176,14 @@ TEST_F(SourceImplTest, forwardToRippled)
{
boost::json::object const request = {{"some_key", "some_value"}};
std::optional<std::string> const clientIp = "some_client_ip";
std::string_view xUserValue = "some_user";
EXPECT_CALL(forwardingSourceMock_, forwardToRippled(request, clientIp, testing::_)).WillOnce(Return(request));
EXPECT_CALL(forwardingSourceMock_, forwardToRippled(request, clientIp, xUserValue, testing::_))
.WillOnce(Return(request));
boost::asio::io_context ioContext;
boost::asio::spawn(ioContext, [&](boost::asio::yield_context yield) {
auto const response = source_.forwardToRippled(request, clientIp, yield);
auto const response = source_.forwardToRippled(request, clientIp, xUserValue, yield);
EXPECT_EQ(response, request);
});
ioContext.run();

View File

@@ -333,15 +333,14 @@ TEST_F(RPCForwardingProxyTest, ForwardCallsBalancerWithCorrectParams)
auto const params = json::parse(R"({"test": true})");
auto const forwarded = json::parse(R"({"test": true, "command": "submit"})");
ON_CALL(*rawBalancerPtr, forwardToRippled).WillByDefault(Return(std::make_optional<json::object>()));
EXPECT_CALL(*rawBalancerPtr, forwardToRippled(forwarded.as_object(), std::make_optional<std::string>(CLIENT_IP), _))
.Times(1);
EXPECT_CALL(
*rawBalancerPtr, forwardToRippled(forwarded.as_object(), std::make_optional<std::string>(CLIENT_IP), true, _)
)
.WillOnce(Return(std::make_optional<json::object>()));
ON_CALL(*rawHandlerProviderPtr, contains).WillByDefault(Return(true));
EXPECT_CALL(*rawHandlerProviderPtr, contains(method)).Times(1);
EXPECT_CALL(*rawHandlerProviderPtr, contains(method)).WillOnce(Return(true));
ON_CALL(counters, rpcForwarded).WillByDefault(Return());
EXPECT_CALL(counters, rpcForwarded(method)).Times(1);
EXPECT_CALL(counters, rpcForwarded(method));
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();
@@ -364,15 +363,14 @@ TEST_F(RPCForwardingProxyTest, ForwardingFailYieldsErrorStatus)
auto const params = json::parse(R"({"test": true})");
auto const forwarded = json::parse(R"({"test": true, "command": "submit"})");
ON_CALL(*rawBalancerPtr, forwardToRippled).WillByDefault(Return(std::nullopt));
EXPECT_CALL(*rawBalancerPtr, forwardToRippled(forwarded.as_object(), std::make_optional<std::string>(CLIENT_IP), _))
.Times(1);
EXPECT_CALL(
*rawBalancerPtr, forwardToRippled(forwarded.as_object(), std::make_optional<std::string>(CLIENT_IP), true, _)
)
.WillOnce(Return(std::nullopt));
ON_CALL(*rawHandlerProviderPtr, contains).WillByDefault(Return(true));
EXPECT_CALL(*rawHandlerProviderPtr, contains(method)).Times(1);
EXPECT_CALL(*rawHandlerProviderPtr, contains(method)).WillOnce(Return(true));
ON_CALL(counters, rpcFailedToForward).WillByDefault(Return());
EXPECT_CALL(counters, rpcFailedToForward(method)).Times(1);
EXPECT_CALL(counters, rpcFailedToForward(method));
runSpawn([&](auto yield) {
auto const range = backend->fetchLedgerRange();

View File

@@ -191,7 +191,7 @@ TEST_F(RPCServerInfoHandlerTest, DefaultOutputIsPresent)
auto const feeBlob = CreateLegacyFeeSettingBlob(1, 2, 3, 4, 0);
EXPECT_CALL(*backend, doFetchLedgerObject).WillOnce(Return(feeBlob));
EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), testing::_))
EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), false, testing::_))
.WillOnce(Return(std::nullopt));
EXPECT_CALL(*rawCountersPtr, uptime).WillOnce(Return(std::chrono::seconds{1234}));
@@ -228,7 +228,7 @@ TEST_F(RPCServerInfoHandlerTest, AmendmentBlockedIsPresentIfSet)
auto const feeBlob = CreateLegacyFeeSettingBlob(1, 2, 3, 4, 0);
EXPECT_CALL(*backend, doFetchLedgerObject).WillOnce(Return(feeBlob));
EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), testing::_))
EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), false, testing::_))
.WillOnce(Return(std::nullopt));
EXPECT_CALL(*rawCountersPtr, uptime).WillOnce(Return(std::chrono::seconds{1234}));
@@ -263,7 +263,7 @@ TEST_F(RPCServerInfoHandlerTest, CorruptionDetectedIsPresentIfSet)
auto const feeBlob = CreateLegacyFeeSettingBlob(1, 2, 3, 4, 0);
EXPECT_CALL(*backend, doFetchLedgerObject).WillOnce(Return(feeBlob));
EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), testing::_))
EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), false, testing::_))
.WillOnce(Return(std::nullopt));
EXPECT_CALL(*rawCountersPtr, uptime).WillOnce(Return(std::chrono::seconds{1234}));
@@ -297,7 +297,7 @@ TEST_F(RPCServerInfoHandlerTest, CacheReportsEnabledFlagCorrectly)
auto const feeBlob = CreateLegacyFeeSettingBlob(1, 2, 3, 4, 0);
EXPECT_CALL(*backend, doFetchLedgerObject).Times(2).WillRepeatedly(Return(feeBlob));
EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), testing::_))
EXPECT_CALL(*rawBalancerPtr, forwardToRippled(testing::_, testing::Eq(CLIENTIP), false, testing::_))
.Times(2)
.WillRepeatedly(Return(std::nullopt));