diff --git a/src/web/ProxyIpResolver.cpp b/src/web/ProxyIpResolver.cpp index d0cf97b8c..04b82a294 100644 --- a/src/web/ProxyIpResolver.cpp +++ b/src/web/ProxyIpResolver.cpp @@ -8,6 +8,7 @@ #include +#include #include #include #include @@ -49,20 +50,20 @@ ProxyIpResolver::fromConfig(util::config::ClioConfigDefinition const& config) return ProxyIpResolver{std::move(ips), std::move(tokens)}; } -std::string +std::optional ProxyIpResolver::resolveClientIp(std::string const& connectionIp, HttpHeaders const& headers) const { if (proxyIps_.contains(connectionIp)) { - return extractClientIp(headers).value_or(connectionIp); + return extractClientIp(headers); } if (auto it = headers.find(kPROXY_TOKEN_HEADER); it != headers.end()) { auto const tokenHash = util::sha256sum(it->value()); if (proxyTokens_.contains(tokenHash)) { - return extractClientIp(headers).value_or(connectionIp); + return extractClientIp(headers); } } - return connectionIp; + return std::nullopt; } std::optional @@ -78,14 +79,17 @@ ProxyIpResolver::extractClientIp(HttpHeaders const& headers) auto const headerValue = util::toLower(it->value()); static constexpr std::string_view kFOR_PREFIX = "for="; - auto const startPos = headerValue.find(kFOR_PREFIX); + auto const startPos = headerValue.rfind(kFOR_PREFIX); if (startPos == std::string::npos) { return std::nullopt; } auto value = it->value().substr(startPos + kFOR_PREFIX.size()); - static constexpr char kDELIMITER = ';'; - auto const endPos = value.find(kDELIMITER); + static constexpr char kSECTION_DELIMITER = ';'; + static constexpr char kCHAIN_DELIMITER = ','; + auto const sectionEnd = value.find(kSECTION_DELIMITER); + auto const chainEnd = value.find(kCHAIN_DELIMITER); + auto const endPos = std::min(sectionEnd, chainEnd); auto const ip = value.substr(0, endPos); static constexpr auto kMIN_IP_LENGTH = 7; // minimum 3 dots + 4 digits diff --git a/src/web/ProxyIpResolver.hpp b/src/web/ProxyIpResolver.hpp index fc9effbbc..b5dc014af 100644 --- a/src/web/ProxyIpResolver.hpp +++ b/src/web/ProxyIpResolver.hpp @@ -59,16 +59,16 @@ public: * * If the connection IP is in the trusted proxy list, or if a valid proxy token is provided in * the headers, this method will attempt to extract the client's IP from the `Forwarded` header. - * Otherwise, it returns the connection IP. + * Otherwise, returns std::nullopt. * * @param connectionIp The IP address of the direct connection. * @param headers The HTTP request headers. - * @return The resolved client IP address as a string. + * @return The resolved client IP address if the connection is from a trusted proxy, otherwise + * std::nullopt. */ - std::string + std::optional resolveClientIp(std::string const& connectionIp, HttpHeaders const& headers) const; -private: /** * @brief Extracts the client IP from the `Forwarded` HTTP header. * diff --git a/src/web/impl/HttpBase.hpp b/src/web/impl/HttpBase.hpp index 9935063f9..04c128325 100644 --- a/src/web/impl/HttpBase.hpp +++ b/src/web/impl/HttpBase.hpp @@ -122,6 +122,7 @@ class HttpBase : public ConnectionBase { SendLambda sender_; std::shared_ptr adminVerification_; std::shared_ptr proxyIpResolver_; + bool isProxyConnection_ = false; protected: boost::beast::flat_buffer buffer_; @@ -222,14 +223,26 @@ public: if (ec) return httpFail(ec, "read"); - if (auto resolvedIp = proxyIpResolver_->resolveClientIp(clientIp_, req_); - resolvedIp != clientIp_) { + auto const updateClientIp = [&](std::string newIp) { + if (newIp == clientIp_) + return; LOG(log_.info()) << tag() - << "Detected a forwarded request from proxy. Proxy ip: " << clientIp_ - << ". Resolved client ip: " << resolvedIp; + << "Detected a forwarded request from proxy. Resolved client ip: " + << newIp; dosGuard_.get().decrement(clientIp_); - clientIp_ = std::move(resolvedIp); + clientIp_ = std::move(newIp); dosGuard_.get().increment(clientIp_); + }; + + if (isProxyConnection_) { + if (auto resolvedIp = ProxyIpResolver::extractClientIp(req_); resolvedIp.has_value()) + updateClientIp(std::move(*resolvedIp)); + } else if ( + auto resolvedIp = proxyIpResolver_->resolveClientIp(clientIp_, req_); + resolvedIp.has_value() + ) { + updateClientIp(std::move(*resolvedIp)); + isProxyConnection_ = true; } if (req_.method() == http::verb::get and req_.target() == "/health") diff --git a/src/web/ng/Connection.hpp b/src/web/ng/Connection.hpp index cf7f3693e..67f3066c2 100644 --- a/src/web/ng/Connection.hpp +++ b/src/web/ng/Connection.hpp @@ -26,6 +26,7 @@ class ConnectionMetadata : public util::Taggable { protected: std::string ip_; // client ip std::optional isAdmin_; + bool isProxyConnection_ = false; public: /** @@ -63,6 +64,26 @@ public: ip_ = std::move(newIp); } + /** + * @brief Mark this connection as coming through a trusted proxy. + */ + void + markAsProxyConnection() + { + isProxyConnection_ = true; + } + + /** + * @brief Whether this connection was identified as coming through a trusted proxy. + * + * @return true if the connection is a proxy connection. + */ + [[nodiscard]] bool + isProxyConnection() const + { + return isProxyConnection_; + } + /** * @brief Get whether the client is an admin. * diff --git a/src/web/ng/impl/ConnectionHandler.cpp b/src/web/ng/impl/ConnectionHandler.cpp index 8d34bec7e..9ac7ed056 100644 --- a/src/web/ng/impl/ConnectionHandler.cpp +++ b/src/web/ng/impl/ConnectionHandler.cpp @@ -402,14 +402,26 @@ ConnectionHandler::handleRequest( void ConnectionHandler::resolveClientIp(Connection& connection, Request const& request) const { - if (auto resolvedClientIp = - proxyIpResolver_.resolveClientIp(connection.ip(), request.httpHeaders()); - resolvedClientIp != connection.ip()) { + auto const updateIp = [&](std::string newIp) { + if (newIp == connection.ip()) + return; LOG(log_.info()) << connection.tag() - << "Detected a forwarded request from proxy. Proxy ip: " << connection.ip() - << ". Resolved client ip: " << resolvedClientIp; - onIpChangeHook_(connection.ip(), resolvedClientIp); - connection.setIp(std::move(resolvedClientIp)); + << "Detected a forwarded request from proxy. Resolved client ip: " + << newIp; + onIpChangeHook_(connection.ip(), newIp); + connection.setIp(std::move(newIp)); + }; + + if (connection.isProxyConnection()) { + if (auto resolvedIp = ProxyIpResolver::extractClientIp(request.httpHeaders()); + resolvedIp.has_value()) + updateIp(std::move(*resolvedIp)); + } else if ( + auto resolvedIp = proxyIpResolver_.resolveClientIp(connection.ip(), request.httpHeaders()); + resolvedIp.has_value() + ) { + updateIp(std::move(*resolvedIp)); + connection.markAsProxyConnection(); } } diff --git a/tests/unit/web/ProxyIpResolverTests.cpp b/tests/unit/web/ProxyIpResolverTests.cpp index 808fc58d0..96c5d91d0 100644 --- a/tests/unit/web/ProxyIpResolverTests.cpp +++ b/tests/unit/web/ProxyIpResolverTests.cpp @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -26,7 +27,7 @@ struct ProxyIpResolverTestParams { std::unordered_set proxyTokens; std::vector> headers; std::string connectionIp; - std::string expectedIp; + std::optional expectedIp; }; class ProxyIpResolverTest : public ::testing::TestWithParam {}; @@ -61,11 +62,11 @@ TEST_F(ProxyIpResolverTest, FromConfig) auto const proxyIpResolver = ProxyIpResolver::fromConfig(config); ProxyIpResolver::HttpHeaders headers; - EXPECT_EQ(proxyIpResolver.resolveClientIp(clientIp, headers), clientIp); - EXPECT_EQ(proxyIpResolver.resolveClientIp(proxyIp, headers), proxyIp); + EXPECT_EQ(proxyIpResolver.resolveClientIp(clientIp, headers), std::nullopt); + EXPECT_EQ(proxyIpResolver.resolveClientIp(proxyIp, headers), std::nullopt); headers.set(boost::beast::http::field::forwarded, fmt::format("for={}", clientIp)); - EXPECT_EQ(proxyIpResolver.resolveClientIp(clientIp, headers), clientIp); + EXPECT_EQ(proxyIpResolver.resolveClientIp(clientIp, headers), std::nullopt); EXPECT_EQ(proxyIpResolver.resolveClientIp(proxyIp, headers), clientIp); headers.set(ProxyIpResolver::kPROXY_TOKEN_HEADER, proxyToken); @@ -95,7 +96,7 @@ INSTANTIATE_TEST_SUITE_P( .proxyTokens = {}, .headers = {}, .connectionIp = "1.2.3.4", - .expectedIp = "1.2.3.4" + .expectedIp = std::nullopt }, ProxyIpResolverTestParams{ .testName = "TrustedProxyIpWithForwardedHeader", @@ -111,7 +112,7 @@ INSTANTIATE_TEST_SUITE_P( .proxyTokens = {}, .headers = {}, .connectionIp = "5.6.7.8", - .expectedIp = "5.6.7.8" + .expectedIp = std::nullopt }, ProxyIpResolverTestParams{ .testName = "UntrustedProxyIpWithForwardedHeader", @@ -119,7 +120,7 @@ INSTANTIATE_TEST_SUITE_P( .proxyTokens = {}, .headers = {{std::string(http::to_string(http::field::forwarded)), "for=1.2.3.4"}}, .connectionIp = "5.6.7.8", - .expectedIp = "5.6.7.8" + .expectedIp = std::nullopt }, ProxyIpResolverTestParams{ .testName = "TrustedProxyTokenWithForwardedHeader", @@ -137,7 +138,7 @@ INSTANTIATE_TEST_SUITE_P( .proxyTokens = {"test_token"}, .headers = {{std::string(ProxyIpResolver::kPROXY_TOKEN_HEADER), "test_token"}}, .connectionIp = "5.6.7.8", - .expectedIp = "5.6.7.8" + .expectedIp = std::nullopt }, ProxyIpResolverTestParams{ .testName = "UntrustedProxyTokenWithForwardedHeader", @@ -147,7 +148,7 @@ INSTANTIATE_TEST_SUITE_P( {{std::string(ProxyIpResolver::kPROXY_TOKEN_HEADER), "test_token"}, {std::string(http::to_string(http::field::forwarded)), "for=1.2.3.4"}}, .connectionIp = "5.6.7.8", - .expectedIp = "5.6.7.8" + .expectedIp = std::nullopt }, ProxyIpResolverTestParams{ .testName = "ForwardedHeaderWithAdditionalFields", @@ -173,7 +174,7 @@ INSTANTIATE_TEST_SUITE_P( .proxyTokens = {}, .headers = {{std::string(http::to_string(http::field::forwarded)), "by=1.2.3.4"}}, .connectionIp = "5.6.7.8", - .expectedIp = "5.6.7.8" + .expectedIp = std::nullopt }, ProxyIpResolverTestParams{ .testName = "ForwardedHeaderWithIpInQuotes", @@ -190,7 +191,27 @@ INSTANTIATE_TEST_SUITE_P( .headers = {{std::string(http::to_string(http::field::forwarded)), "for=\";some_other_text"}}, .connectionIp = "5.6.7.8", - .expectedIp = "5.6.7.8" + .expectedIp = std::nullopt + }, + ProxyIpResolverTestParams{ + .testName = "ForwardedHeaderWithMultipleForValues", + .proxyIps = {"5.6.7.8"}, + .proxyTokens = {}, + .headers = + {{std::string(http::to_string(http::field::forwarded)), + "for=1.2.3.4, for=9.10.11.12"}}, + .connectionIp = "5.6.7.8", + .expectedIp = "9.10.11.12" + }, + ProxyIpResolverTestParams{ + .testName = "ForwardedHeaderWithMultipleForValuesAndSectionDelimiters", + .proxyIps = {"5.6.7.8"}, + .proxyTokens = {}, + .headers = + {{std::string(http::to_string(http::field::forwarded)), + "for=1.2.3.4; proto=http, for=9.10.11.12; proto=https"}}, + .connectionIp = "5.6.7.8", + .expectedIp = "9.10.11.12" } ), tests::util::kNAME_GENERATOR diff --git a/tests/unit/web/ng/impl/ConnectionHandlerTests.cpp b/tests/unit/web/ng/impl/ConnectionHandlerTests.cpp index ea61344d9..31d8f5b4f 100644 --- a/tests/unit/web/ng/impl/ConnectionHandlerTests.cpp +++ b/tests/unit/web/ng/impl/ConnectionHandlerTests.cpp @@ -593,6 +593,110 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, OnIpChangeHookCalledWhenSentFr ) mutable { connectionHandler.processConnection(std::move(c), yield); }); } +TEST_F(ConnectionHandlerSequentialProcessingTest, ProxyConnection_SameClientReuses_HookCalledOnce) +{ + std::string const target = "/some/target"; + testing::StrictMock> + getHandlerMock; + connectionHandler.onGet(target, getHandlerMock.AsStdFunction()); + + StrictMockHttpConnectionPtr mockProxyConnection = std::make_unique( + proxyIp, boost::beast::flat_buffer{}, tagDecoratorFactory + ); + + auto request = http::request{http::verb::get, target, 11, ""}; + request.set(http::field::forwarded, fmt::format("for={}", clientIp)); + + EXPECT_CALL(*mockProxyConnection, wasUpgraded).WillOnce(Return(false)); + EXPECT_CALL(*mockProxyConnection, receive) + .WillOnce(Return(makeRequest(request))) + .WillOnce(Return(makeRequest(request))) + .WillOnce(Return(makeError(http::error::end_of_stream))); + + EXPECT_CALL(onIpChangeMock, Call(proxyIp, clientIp)); + + EXPECT_CALL(getHandlerMock, Call) + .Times(2) + .WillRepeatedly([](Request const& req, auto&&, auto&&, auto&&) { + return Response(http::status::ok, "ok", req); + }); + + EXPECT_CALL(*mockProxyConnection, send) + .Times(2) + .WillRepeatedly(Return(std::expected{})); + + EXPECT_CALL(onDisconnectMock, Call) + .WillOnce([this, ptr = mockProxyConnection.get()](Connection const& c) { + EXPECT_EQ(&c, ptr); + EXPECT_EQ(c.ip(), clientIp); + }); + + runSpawn([this, c = std::move(mockProxyConnection)](boost::asio::yield_context yield) mutable { + connectionHandler.processConnection(std::move(c), yield); + }); +} + +TEST_F( + ConnectionHandlerSequentialProcessingTest, + ProxyConnection_DifferentClientReuses_HookCalledForEachIpChange +) +{ + std::string const target = "/some/target"; + std::string const anotherClientIp = "9.10.11.12"; + testing::StrictMock> + getHandlerMock; + connectionHandler.onGet(target, getHandlerMock.AsStdFunction()); + + StrictMockHttpConnectionPtr mockProxyConnection = std::make_unique( + proxyIp, boost::beast::flat_buffer{}, tagDecoratorFactory + ); + + auto request1 = http::request{http::verb::get, target, 11, ""}; + request1.set(http::field::forwarded, fmt::format("for={}", clientIp)); + + auto request2 = http::request{http::verb::get, target, 11, ""}; + request2.set(http::field::forwarded, fmt::format("for={}", anotherClientIp)); + + EXPECT_CALL(*mockProxyConnection, wasUpgraded).WillOnce(Return(false)); + EXPECT_CALL(*mockProxyConnection, receive) + .WillOnce(Return(makeRequest(request1))) + .WillOnce(Return(makeRequest(request2))) + .WillOnce(Return(makeError(http::error::end_of_stream))); + + EXPECT_CALL(onIpChangeMock, Call(proxyIp, clientIp)); + EXPECT_CALL(onIpChangeMock, Call(clientIp, anotherClientIp)); + + EXPECT_CALL(getHandlerMock, Call) + .Times(2) + .WillRepeatedly([](Request const& req, auto&&, auto&&, auto&&) { + return Response(http::status::ok, "ok", req); + }); + + EXPECT_CALL(*mockProxyConnection, send) + .Times(2) + .WillRepeatedly(Return(std::expected{})); + + EXPECT_CALL(onDisconnectMock, Call) + .WillOnce([anotherClientIp, ptr = mockProxyConnection.get()](Connection const& c) { + EXPECT_EQ(&c, ptr); + EXPECT_EQ(c.ip(), anotherClientIp); + }); + + runSpawn([this, c = std::move(mockProxyConnection)](boost::asio::yield_context yield) mutable { + connectionHandler.processConnection(std::move(c), yield); + }); +} + TEST_F(ConnectionHandlerSequentialProcessingTest, Stop) { testing::StrictMock( + proxyIp, boost::beast::flat_buffer{}, tagDecoratorFactory + ); + + headers.set(http::field::forwarded, fmt::format("for={}", clientIp)); + + EXPECT_CALL(*mockProxyConnection, wasUpgraded).WillOnce(Return(true)); + EXPECT_CALL(*mockProxyConnection, receive) + .WillOnce(Return(makeRequest("msg", headers))) + .WillOnce(Return(makeRequest("msg", headers))) + .WillOnce(Return(makeError(websocket::error::closed))); + + EXPECT_CALL(onIpChangeMock, Call(proxyIp, clientIp)); + + EXPECT_CALL(*mockProxyConnection, send) + .Times(2) + .WillRepeatedly(Return(std::expected{})); + + EXPECT_CALL(onDisconnectMock, Call) + .WillOnce([this, ptr = mockProxyConnection.get()](Connection const& c) { + EXPECT_EQ(&c, ptr); + EXPECT_EQ(c.ip(), clientIp); + }); + + runSpawn([this, c = std::move(mockProxyConnection)](boost::asio::yield_context yield) mutable { + connectionHandler.processConnection(std::move(c), yield); + }); +} + +TEST_F( + ConnectionHandlerParallelProcessingTest, + ProxyConnection_DifferentClientReuses_HookCalledForEachIpChange +) +{ + std::string const anotherClientIp = "9.10.11.12"; + connectionHandler.onWs([](Request const& req, auto&&, auto&&, auto&&) { + return Response(http::status::ok, "ok", req); + }); + + StrictMockWsConnectionPtr mockProxyConnection = std::make_unique( + proxyIp, boost::beast::flat_buffer{}, tagDecoratorFactory + ); + + Request::HttpHeaders headers1; + headers1.set(http::field::forwarded, fmt::format("for={}", clientIp)); + + Request::HttpHeaders headers2; + headers2.set(http::field::forwarded, fmt::format("for={}", anotherClientIp)); + + EXPECT_CALL(*mockProxyConnection, wasUpgraded).WillOnce(Return(true)); + EXPECT_CALL(*mockProxyConnection, receive) + .WillOnce(Return(makeRequest("msg", headers1))) + .WillOnce(Return(makeRequest("msg", headers2))) + .WillOnce(Return(makeError(websocket::error::closed))); + + EXPECT_CALL(onIpChangeMock, Call(proxyIp, clientIp)); + EXPECT_CALL(onIpChangeMock, Call(clientIp, anotherClientIp)); + + EXPECT_CALL(*mockProxyConnection, send) + .Times(2) + .WillRepeatedly(Return(std::expected{})); + + EXPECT_CALL(onDisconnectMock, Call) + .WillOnce([anotherClientIp, ptr = mockProxyConnection.get()](Connection const& c) { + EXPECT_EQ(&c, ptr); + EXPECT_EQ(c.ip(), anotherClientIp); + }); + + runSpawn([this, c = std::move(mockProxyConnection)](boost::asio::yield_context yield) mutable { + connectionHandler.processConnection(std::move(c), yield); + }); +} + TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop) { testing::StrictMock