feat: Proxy support (#2490)

Add client IP resolving support in case when there is a proxy in front
of Clio.
This commit is contained in:
Sergey Kuznetsov
2025-09-03 15:22:47 +01:00
committed by GitHub
parent 0a2930d861
commit 3a667f558c
39 changed files with 1042 additions and 125 deletions

View File

@@ -329,6 +329,22 @@ This document provides a list of all available Clio configuration properties in
- **Constraints**: The minimum value is `1`. The maximum value is `4294967295`.
- **Description**: Maximum queue size for sending subscription data to clients. This queue buffers data when a client is slow to receive it, ensuring delivery once the client is ready.
### server.proxy.ips.[]
- **Required**: True
- **Type**: string
- **Default value**: None
- **Constraints**: None
- **Description**: List of proxy ip addresses. When Clio receives a request from proxy it will use `Forwarded` value (if any) as client ip. When this option is used together with `server.proxy.tokens` Clio will identify proxy by ip or by token.
### server.proxy.tokens.[]
- **Required**: True
- **Type**: string
- **Default value**: None
- **Constraints**: None
- **Description**: List of tokens in identifying request as a request from proxy. Token should be provided in `X-Proxy-Token` header, e.g. `X-Proxy-Token: <very_secret_token>'. When Clio receives a request from proxy it will use 'Forwarded` value (if any) to get client ip. When this option is used together with 'server.proxy.ips' Clio will identify proxy by ip or by token.
### prometheus.enabled
- **Required**: True

View File

@@ -76,7 +76,11 @@
"parallel_requests_limit": 10, // Optional parameter, used only if "processing_strategy" is "parallel". It limits the number of requests for one client connection processed in parallel. Infinite if not specified.
// Max number of responses to queue up before sent successfully. If a client's waiting queue is too long, the server will close the connection.
"ws_max_sending_queue_size": 1500,
"__ng_web_server": false // Use ng web server. This is a temporary setting which will be deleted after switching to ng web server
"__ng_web_server": false, // Use ng web server. This is a temporary setting which will be deleted after switching to ng web server
"proxy": {
"ips": [],
"tokens": []
}
},
// Time in seconds for graceful shutdown. Defaults to 10 seconds. Not fully implemented yet.
"graceful_period": 10.0,

View File

@@ -178,7 +178,9 @@ ClioApplication::run(bool const useNgWebServer)
}
auto const adminVerifier = std::move(expectedAdminVerifier).value();
auto httpServer = web::ng::makeServer(config_, OnConnectCheck{dosGuard}, DisconnectHook{dosGuard}, ioc);
auto httpServer = web::ng::makeServer(
config_, OnConnectCheck{dosGuard}, IpChangeHook{dosGuard}, DisconnectHook{dosGuard}, ioc
);
if (not httpServer.has_value()) {
LOG(util::LogService::error()) << "Error creating web server: " << httpServer.error();

View File

@@ -33,6 +33,7 @@
#include <memory>
#include <optional>
#include <string>
#include <utility>
namespace app {
@@ -54,6 +55,17 @@ OnConnectCheck::operator()(web::ng::Connection const& connection)
return {};
}
IpChangeHook::IpChangeHook(web::dosguard::DOSGuardInterface& dosguard) : dosguard_(dosguard)
{
}
void
IpChangeHook::operator()(std::string const& oldIp, std::string const& newIp)
{
dosguard_.get().decrement(oldIp);
dosguard_.get().increment(newIp);
}
DisconnectHook::DisconnectHook(web::dosguard::DOSGuardInterface& dosguard) : dosguard_{dosguard}
{
}

View File

@@ -36,6 +36,7 @@
#include <exception>
#include <functional>
#include <memory>
#include <string>
#include <utility>
namespace app {
@@ -64,6 +65,31 @@ public:
operator()(web::ng::Connection const& connection);
};
/**
* @brief A function object that is called when the IP of a connection changes (usually if proxy detected).
* This is used to update the DOS guard.
*/
class IpChangeHook {
std::reference_wrapper<web::dosguard::DOSGuardInterface> dosguard_;
public:
/**
* @brief Construct a new IpChangeHook object.
*
* @param dosguard The DOS guard to use.
*/
IpChangeHook(web::dosguard::DOSGuardInterface& dosguard);
/**
* @brief The call of the function object.
*
* @param oldIp The old IP of the connection.
* @param newIp The new IP of the connection.
*/
void
operator()(std::string const& oldIp, std::string const& newIp);
};
/**
* @brief A function object to be called when a connection is disconnected.
*/

View File

@@ -24,6 +24,7 @@ target_sources(
requests/WsConnection.cpp
requests/impl/SslContext.cpp
ResponseExpirationCache.cpp
Shasum.cpp
SignalsHandler.cpp
StopHelper.cpp
StringHash.cpp

48
src/util/Shasum.cpp Normal file
View File

@@ -0,0 +1,48 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "util/Shasum.hpp"
#include <xrpl/basics/base_uint.h>
#include <xrpl/protocol/digest.h>
#include <cstring>
#include <string>
#include <string_view>
namespace util {
ripple::uint256
sha256sum(std::string_view s)
{
ripple::sha256_hasher hasher;
hasher(s.data(), s.size());
auto const hashData = static_cast<ripple::sha256_hasher::result_type>(hasher);
ripple::uint256 sha256;
std::memcpy(sha256.data(), hashData.data(), hashData.size());
return sha256;
}
std::string
sha256sumString(std::string_view s)
{
return ripple::to_string(sha256sum(s));
}
} // namespace util

46
src/util/Shasum.hpp Normal file
View File

@@ -0,0 +1,46 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include <xrpl/basics/base_uint.h>
#include <string>
#include <string_view>
namespace util {
/**
* @brief Calculates the SHA256 sum of a string.
*
* @param s The string to hash.
* @return The SHA256 sum as a ripple::uint256.
*/
ripple::uint256
sha256sum(std::string_view s);
/**
* @brief Calculates the SHA256 sum of a string and returns it as a hex string.
*
* @param s The string to hash.
* @return The SHA256 sum as a hex string.
*/
std::string
sha256sumString(std::string_view s);
} // namespace util

View File

@@ -337,6 +337,8 @@ getClioConfig()
{"server.ws_max_sending_queue_size",
ConfigValue{ConfigType::Integer}.defaultValue(1500).withConstraint(gValidateUint32)},
{"server.__ng_web_server", ConfigValue{ConfigType::Boolean}.defaultValue(false)},
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
{"prometheus.enabled", ConfigValue{ConfigType::Boolean}.defaultValue(true)},
{"prometheus.compress_reply", ConfigValue{ConfigType::Boolean}.defaultValue(true)},

View File

@@ -236,6 +236,16 @@ This document provides a list of all available Clio configuration properties in
KV{.key = "server.ws_max_sending_queue_size",
.value = "Maximum queue size for sending subscription data to clients. This queue buffers data when a "
"client is slow to receive it, ensuring delivery once the client is ready."},
KV{.key = "server.proxy.ips.[]",
.value = "List of proxy ip addresses. When Clio receives a request from proxy it will use "
"`Forwarded` value (if any) as client ip. When this option is used together with "
"`server.proxy.tokens` Clio will identify proxy by ip or by token."},
KV{.key = "server.proxy.tokens.[]",
.value = "List of tokens in identifying request as a request from proxy. Token should be provided in "
"`X-Proxy-Token` header, e.g. "
"`X-Proxy-Token: <very_secret_token>'. When Clio receives a request from proxy "
"it will use 'Forwarded` value (if any) to get client ip. When this option is used together with "
"'server.proxy.ips' Clio will identify proxy by ip or by token."},
KV{.key = "prometheus.enabled", .value = "Enables or disables Prometheus metrics."},
KV{.key = "prometheus.compress_reply", .value = "Enables or disables compression of Prometheus responses."},
KV{.key = "io_threads",

View File

@@ -20,6 +20,7 @@
#include "web/AdminVerificationStrategy.hpp"
#include "util/JsonUtils.hpp"
#include "util/Shasum.hpp"
#include "util/config/ConfigDefinition.hpp"
#include <boost/beast/http/field.hpp>
@@ -42,15 +43,8 @@ IPAdminVerificationStrategy::isAdmin(RequestHeader const&, std::string_view ip)
}
PasswordAdminVerificationStrategy::PasswordAdminVerificationStrategy(std::string const& password)
: passwordSha256_(util::toUpper(util::sha256sumString(password)))
{
ripple::sha256_hasher hasher;
hasher(password.data(), password.size());
auto const d = static_cast<ripple::sha256_hasher::result_type>(hasher);
ripple::uint256 sha256;
std::memcpy(sha256.data(), d.data(), d.size());
passwordSha256_ = ripple::to_string(sha256);
// make sure it's uppercase
passwordSha256_ = util::toUpper(std::move(passwordSha256_));
}
bool

View File

@@ -15,6 +15,7 @@ target_sources(
ng/Response.cpp
ng/Server.cpp
ng/SubscriptionContext.cpp
ProxyIpResolver.cpp
Resolver.cpp
SubscriptionContext.cpp
)

View File

@@ -22,6 +22,7 @@
#include "util/Taggable.hpp"
#include "web/AdminVerificationStrategy.hpp"
#include "web/PlainWsSession.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/dosguard/DOSGuardInterface.hpp"
#include "web/impl/HttpBase.hpp"
#include "web/interface/Concepts.hpp"
@@ -64,6 +65,7 @@ public:
* @param socket The socket. Ownership is transferred to HttpSession
* @param ip Client's IP address
* @param adminVerification The admin verification strategy to use
* @param proxyIpResolver The client ip resolver if a request was forwarded by a proxy
* @param tagFactory A factory that is used to generate tags to track requests and sessions
* @param dosGuard The denial of service guard to use
* @param handler The server handler to use
@@ -74,6 +76,7 @@ public:
tcp::socket&& socket,
std::string const& ip,
std::shared_ptr<AdminVerificationStrategy> const& adminVerification,
std::shared_ptr<ProxyIpResolver> proxyIpResolver,
std::reference_wrapper<util::TagDecoratorFactory const> tagFactory,
std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard,
std::shared_ptr<HandlerType> const& handler,
@@ -84,6 +87,7 @@ public:
ip,
tagFactory,
adminVerification,
std::move(proxyIpResolver),
dosGuard,
handler,
std::move(buffer)
@@ -129,7 +133,7 @@ public:
{
std::make_shared<WsUpgrader<HandlerType>>(
std::move(stream_),
this->clientIp,
this->clientIp_,
tagFactory_,
this->dosGuard_,
this->handler_,

119
src/web/ProxyIpResolver.cpp Normal file
View File

@@ -0,0 +1,119 @@
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "web/ProxyIpResolver.hpp"
#include "util/JsonUtils.hpp"
#include "util/Shasum.hpp"
#include "util/config/ArrayView.hpp"
#include "util/config/ConfigDefinition.hpp"
#include "util/config/ValueView.hpp"
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/string_body.hpp>
#include <xrpl/basics/base_uint.h>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_set>
#include <utility>
namespace web {
ProxyIpResolver::ProxyIpResolver(std::unordered_set<std::string> proxyIps, std::unordered_set<std::string> proxyTokens)
: proxyIps_(std::move(proxyIps))
{
proxyTokens_.reserve(proxyTokens.size());
for (auto const& t : proxyTokens) {
proxyTokens_.insert(util::sha256sum(t));
}
}
ProxyIpResolver
ProxyIpResolver::fromConfig(util::config::ClioConfigDefinition const& config)
{
using util::config::ValueView;
std::unordered_set<std::string> ips;
auto const ipsFromConfig = config.getArray("server.proxy.ips");
for (auto it = ipsFromConfig.begin<ValueView>(); it != ipsFromConfig.end<ValueView>(); ++it) {
ips.insert((*it).asString());
}
std::unordered_set<std::string> tokens;
auto const tokensFromConfig = config.getArray("server.proxy.tokens");
for (auto it = tokensFromConfig.begin<ValueView>(); it != tokensFromConfig.end<ValueView>(); ++it) {
tokens.insert((*it).asString());
}
return ProxyIpResolver{std::move(ips), std::move(tokens)};
}
std::string
ProxyIpResolver::resolveClientIp(std::string const& connectionIp, HttpHeaders const& headers) const
{
if (proxyIps_.contains(connectionIp)) {
return extractClientIp(headers).value_or(connectionIp);
}
if (auto it = headers.find(kPROXY_TOKEN_HEADER); it != headers.end()) {
auto const tokenHash = util::sha256sum(it->value());
if (proxyTokens_.contains(tokenHash)) {
return extractClientIp(headers).value_or(connectionIp);
}
}
return connectionIp;
}
std::optional<std::string>
ProxyIpResolver::extractClientIp(HttpHeaders const& headers)
{
auto const it = headers.find(boost::beast::http::field::forwarded);
if (it == headers.end()) {
return std::nullopt;
}
// Forwarded header is case insensitive:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Forwarded#using_the_forwarded_header
auto const headerValue = util::toLower(it->value());
static constexpr std::string_view kFOR_PREFIX = "for=";
auto const startPos = headerValue.find(kFOR_PREFIX);
if (startPos == std::string::npos) {
return std::nullopt;
}
auto value = it->value().substr(startPos + kFOR_PREFIX.size());
static constexpr char kDELIMITER = ';';
auto const endPos = value.find(kDELIMITER);
auto const ip = value.substr(0, endPos);
static constexpr auto kMIN_IP_LENGTH = 7; // minimum 3 dots + 4 digits
if (ip.size() < kMIN_IP_LENGTH) {
return std::nullopt;
}
if (ip.starts_with('"')) {
return ip.substr(1, ip.size() - 2);
}
return ip;
}
} // namespace web

101
src/web/ProxyIpResolver.hpp Normal file
View File

@@ -0,0 +1,101 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include "util/config/ConfigDefinition.hpp"
#include "util/config/ValueView.hpp"
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/string_body.hpp>
#include <xrpl/basics/base_uint.h>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_set>
namespace web {
/**
* @brief Resolves the client's IP address, considering proxy servers.
*
* This class is designed to determine the original IP address of a client when the connection
* is forwarded through a proxy server. It uses a configurable list of trusted proxy IPs
* and proxy tokens to decide whether to trust the `Forwarded` HTTP header.
*/
class ProxyIpResolver {
std::unordered_set<std::string> proxyIps_;
// ripple::uint256 doesn't have hash implementation
std::unordered_set<ripple::uint256, ripple::uint256::hasher> proxyTokens_;
public:
/**
* @brief Constructs a ProxyIpResolver.
*
* @param proxyIps A set of trusted proxy IP addresses.
* @param proxyTokens A set of trusted proxy tokens. The tokens will be hashed with SHA-256.
*/
ProxyIpResolver(std::unordered_set<std::string> proxyIps, std::unordered_set<std::string> proxyTokens);
/**
* @brief Creates a ProxyIpResolver from a configuration.
*
* The configuration should contain `server.proxy.ips` and `server.proxy.tokens` arrays.
*
* @param config The Clio configuration.
* @return A new ProxyIpResolver instance.
*/
static ProxyIpResolver
fromConfig(util::config::ClioConfigDefinition const& config);
using HttpHeaders = boost::beast::http::request<boost::beast::http::string_body>::header_type;
static constexpr std::string_view kPROXY_TOKEN_HEADER = "X-Proxy-Token";
/**
* @brief Resolves the client's IP address from the connection IP and HTTP headers.
*
* If the connection IP is in the trusted proxy list, or if a valid proxy token is provided in the headers,
* this method will attempt to extract the client's IP from the `Forwarded` header.
* Otherwise, it returns the connection IP.
*
* @param connectionIp The IP address of the direct connection.
* @param headers The HTTP request headers.
* @return The resolved client IP address as a string.
*/
std::string
resolveClientIp(std::string const& connectionIp, HttpHeaders const& headers) const;
private:
/**
* @brief Extracts the client IP from the `Forwarded` HTTP header.
*
* The `Forwarded` header is expected to be in the format specified by RFC 7239.
* This function looks for the `for` parameter.
*
* @param headers The HTTP request headers.
* @return The client IP address as a string if found, otherwise std::nullopt.
*/
static std::optional<std::string>
extractClientIp(HttpHeaders const& headers);
};
} // namespace web

View File

@@ -107,7 +107,7 @@ public:
void
operator()(std::string const& request, std::shared_ptr<web::ConnectionBase> const& connection)
{
if (not dosguard_.get().isOk(connection->clientIp)) {
if (not dosguard_.get().isOk(connection->clientIp())) {
connection->sendSlowDown(request);
return;
}
@@ -119,7 +119,7 @@ public:
if (not connection->upgraded and shouldReplaceParams(req))
req[JS(params)] = boost::json::array({boost::json::object{}});
if (not dosguard_.get().request(connection->clientIp, req)) {
if (not dosguard_.get().request(connection->clientIp(), req)) {
connection->sendSlowDown(request);
return;
}
@@ -128,7 +128,7 @@ public:
[this, request = std::move(req), connection](boost::asio::yield_context yield) mutable {
handleRequest(yield, std::move(request), connection);
},
connection->clientIp
connection->clientIp()
)) {
rpcEngine_->notifyTooBusy();
web::impl::ErrorHelper(connection).sendTooBusyError();
@@ -160,7 +160,7 @@ private:
{
LOG(log_.info()) << connection->tag() << (connection->upgraded ? "ws" : "http")
<< " received request from work queue: " << util::removeSecret(request)
<< " ip = " << connection->clientIp;
<< " ip = " << connection->clientIp();
try {
auto const range = backend_->fetchLedgerRange();
@@ -180,7 +180,7 @@ private:
connection->makeSubscriptionContext(tagFactory_),
tagFactory_.with(connection->tag()),
*range,
connection->clientIp,
connection->clientIp(),
std::cref(apiVersionParser_),
connection->isAdmin()
);
@@ -190,7 +190,7 @@ private:
request,
tagFactory_.with(connection->tag()),
*range,
connection->clientIp,
connection->clientIp(),
std::cref(apiVersionParser_),
connection->isAdmin()
);

View File

@@ -23,6 +23,7 @@
#include "util/log/Logger.hpp"
#include "web/AdminVerificationStrategy.hpp"
#include "web/HttpSession.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SslHttpSession.hpp"
#include "web/dosguard/DOSGuardInterface.hpp"
#include "web/interface/Concepts.hpp"
@@ -87,6 +88,7 @@ class Detector : public std::enable_shared_from_this<Detector<PlainSessionType,
boost::beast::flat_buffer buffer_;
std::shared_ptr<AdminVerificationStrategy> const adminVerification_;
std::uint32_t maxWsSendingQueueSize_;
std::shared_ptr<ProxyIpResolver> proxyIpResolver_;
public:
/**
@@ -99,6 +101,7 @@ public:
* @param handler The server handler to use
* @param adminVerification The admin verification strategy to use
* @param maxWsSendingQueueSize The maximum size of the sending queue for websocket
* @param proxyIpResolver The client ip resolver if a request was forwarded by a proxy
*/
Detector(
tcp::socket&& socket,
@@ -107,7 +110,8 @@ public:
std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard,
std::shared_ptr<HandlerType> handler,
std::shared_ptr<AdminVerificationStrategy> adminVerification,
std::uint32_t maxWsSendingQueueSize
std::uint32_t maxWsSendingQueueSize,
std::shared_ptr<ProxyIpResolver> proxyIpResolver
)
: stream_(std::move(socket))
, ctx_(ctx)
@@ -116,6 +120,7 @@ public:
, handler_(std::move(handler))
, adminVerification_(std::move(adminVerification))
, maxWsSendingQueueSize_(maxWsSendingQueueSize)
, proxyIpResolver_(std::move(proxyIpResolver))
{
}
@@ -169,6 +174,7 @@ public:
stream_.release_socket(),
ip,
adminVerification_,
proxyIpResolver_,
*ctx_,
tagFactory_,
dosGuard_,
@@ -184,6 +190,7 @@ public:
stream_.release_socket(),
ip,
adminVerification_,
proxyIpResolver_,
tagFactory_,
dosGuard_,
handler_,
@@ -219,6 +226,7 @@ class Server : public std::enable_shared_from_this<Server<PlainSessionType, SslS
tcp::acceptor acceptor_;
std::shared_ptr<AdminVerificationStrategy> adminVerification_;
std::uint32_t maxWsSendingQueueSize_;
std::shared_ptr<ProxyIpResolver> proxyIpResolver_;
public:
/**
@@ -232,6 +240,7 @@ public:
* @param handler The server handler to use
* @param adminVerification The admin verification strategy to use
* @param maxWsSendingQueueSize The maximum size of the sending queue for websocket
* @param proxyIpResolver The client ip resolver if a request was forwarded by a proxy
*/
Server(
boost::asio::io_context& ioc,
@@ -241,7 +250,8 @@ public:
dosguard::DOSGuardInterface& dosGuard,
std::shared_ptr<HandlerType> handler,
std::shared_ptr<AdminVerificationStrategy> adminVerification,
std::uint32_t maxWsSendingQueueSize
std::uint32_t maxWsSendingQueueSize,
ProxyIpResolver proxyIpResolver
)
: ioc_(std::ref(ioc))
, ctx_(std::move(ctx))
@@ -251,6 +261,7 @@ public:
, acceptor_(boost::asio::make_strand(ioc))
, adminVerification_(std::move(adminVerification))
, maxWsSendingQueueSize_(maxWsSendingQueueSize)
, proxyIpResolver_(std::make_shared<ProxyIpResolver>(std::move(proxyIpResolver)))
{
boost::beast::error_code ec;
@@ -310,7 +321,8 @@ private:
dosGuard_,
handler_,
adminVerification_,
maxWsSendingQueueSize_
maxWsSendingQueueSize_,
proxyIpResolver_
)
->run();
}
@@ -364,6 +376,8 @@ makeHttpServer(
// each ledger. we allow user delay 3 ledgers by default
auto const maxWsSendingQueueSize = serverConfig.get<uint32_t>("ws_max_sending_queue_size");
auto proxyIpResolver = ProxyIpResolver::fromConfig(config);
auto server = std::make_shared<HttpServer<HandlerType>>(
ioc,
std::move(expectedSslContext).value(),
@@ -372,7 +386,8 @@ makeHttpServer(
dosGuard,
handler,
std::move(expectedAdminVerification).value(),
maxWsSendingQueueSize
maxWsSendingQueueSize,
std::move(proxyIpResolver)
);
server->run();

View File

@@ -21,6 +21,7 @@
#include "util/Taggable.hpp"
#include "web/AdminVerificationStrategy.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SslWsSession.hpp"
#include "web/dosguard/DOSGuardInterface.hpp"
#include "web/impl/HttpBase.hpp"
@@ -70,6 +71,7 @@ public:
* @param socket The socket. Ownership is transferred to HttpSession
* @param ip Client's IP address
* @param adminVerification The admin verification strategy to use
* @param proxyIpResolver The client ip resolver if a request was forwarded by a proxy
* @param ctx The SSL context
* @param tagFactory A factory that is used to generate tags to track requests and sessions
* @param dosGuard The denial of service guard to use
@@ -81,6 +83,7 @@ public:
tcp::socket&& socket,
std::string const& ip,
std::shared_ptr<AdminVerificationStrategy> const& adminVerification,
std::shared_ptr<ProxyIpResolver> proxyIpResolver,
boost::asio::ssl::context& ctx,
std::reference_wrapper<util::TagDecoratorFactory const> tagFactory,
std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard,
@@ -92,6 +95,7 @@ public:
ip,
tagFactory,
adminVerification,
std::move(proxyIpResolver),
dosGuard,
handler,
std::move(buffer)
@@ -173,7 +177,7 @@ public:
{
std::make_shared<SslWsUpgrader<HandlerType>>(
std::move(stream_),
this->clientIp,
this->clientIp_,
tagFactory_,
this->dosGuard_,
this->handler_,

View File

@@ -26,6 +26,7 @@
#include "util/log/Logger.hpp"
#include "util/prometheus/Http.hpp"
#include "web/AdminVerificationStrategy.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SubscriptionContextInterface.hpp"
#include "web/dosguard/DOSGuardInterface.hpp"
#include "web/interface/Concepts.hpp"
@@ -120,6 +121,7 @@ class HttpBase : public ConnectionBase {
std::shared_ptr<void> res_;
SendLambda sender_;
std::shared_ptr<AdminVerificationStrategy> adminVerification_;
std::shared_ptr<ProxyIpResolver> proxyIpResolver_;
protected:
boost::beast::flat_buffer buffer_;
@@ -164,6 +166,7 @@ public:
std::string const& ip,
std::reference_wrapper<util::TagDecoratorFactory const> tagFactory,
std::shared_ptr<AdminVerificationStrategy> adminVerification,
std::shared_ptr<ProxyIpResolver> proxyIpResolver,
std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard,
std::shared_ptr<HandlerType> handler,
boost::beast::flat_buffer buffer
@@ -171,6 +174,7 @@ public:
: ConnectionBase(tagFactory, ip)
, sender_(*this)
, adminVerification_(std::move(adminVerification))
, proxyIpResolver_(std::move(proxyIpResolver))
, buffer_(std::move(buffer))
, dosGuard_(dosGuard)
, handler_(std::move(handler))
@@ -183,7 +187,7 @@ public:
{
LOG(perfLog_.debug()) << tag() << "http session closed";
if (not upgraded)
dosGuard_.get().decrement(this->clientIp);
dosGuard_.get().decrement(clientIp_);
}
void
@@ -218,11 +222,19 @@ public:
if (req_.method() == http::verb::get and req_.target() == "/health")
return sender_(httpResponse(http::status::ok, "text/html", kHEALTH_CHECK_HTML));
if (auto resolvedIp = proxyIpResolver_->resolveClientIp(clientIp_, req_); resolvedIp != clientIp_) {
LOG(log_.info()) << tag() << "Detected a forwarded request from proxy. Proxy ip: " << clientIp_
<< ". Resolved client ip: " << resolvedIp;
dosGuard_.get().decrement(clientIp_);
clientIp_ = std::move(resolvedIp);
dosGuard_.get().increment(clientIp_);
}
// Update isAdmin property of the connection
ConnectionBase::isAdmin_ = adminVerification_->isAdmin(req_, this->clientIp);
ConnectionBase::isAdmin_ = adminVerification_->isAdmin(req_, clientIp_);
if (boost::beast::websocket::is_upgrade(req_)) {
if (dosGuard_.get().isOk(this->clientIp)) {
if (dosGuard_.get().isOk(clientIp_)) {
// Disable the timeout. The websocket::stream uses its own timeout settings.
boost::beast::get_lowest_layer(derived().stream()).expires_never();
@@ -240,7 +252,7 @@ public:
return sender_(httpResponse(http::status::bad_request, "text/html", "Expected a POST request"));
}
LOG(log_.info()) << tag() << "Received request from ip = " << clientIp;
LOG(log_.info()) << tag() << "Received request from ip = " << clientIp_;
try {
(*handler_)(req_.body(), derived().shared_from_this());
@@ -271,7 +283,7 @@ public:
void
send(std::string&& msg, http::status status = http::status::ok) override
{
if (!dosGuard_.get().add(clientIp, msg.size())) {
if (!dosGuard_.get().add(clientIp_, msg.size())) {
auto jsonResponse = boost::json::parse(msg).as_object();
jsonResponse["warning"] = "load";
if (jsonResponse.contains("warnings") && jsonResponse["warnings"].is_array()) {

View File

@@ -124,7 +124,7 @@ public:
~WsBase() override
{
LOG(perfLog_.debug()) << tag() << "session closed";
dosGuard_.get().decrement(clientIp);
dosGuard_.get().decrement(clientIp_);
}
Derived<HandlerType>&
@@ -217,7 +217,7 @@ public:
void
send(std::string&& msg, http::status) override
{
if (!dosGuard_.get().add(clientIp, msg.size())) {
if (!dosGuard_.get().add(clientIp_, msg.size())) {
auto jsonResponse = boost::json::parse(msg).as_object();
jsonResponse["warning"] = "load";
@@ -283,7 +283,7 @@ public:
if (ec)
return wsFail(ec, "read");
LOG(perfLog_.info()) << tag() << "Received request from ip = " << this->clientIp;
LOG(perfLog_.info()) << tag() << "Received request from ip = " << clientIp_;
std::string requestStr{static_cast<char const*>(buffer_.data().data()), buffer_.size()};

View File

@@ -45,9 +45,9 @@ struct ConnectionBase : public util::Taggable {
protected:
boost::system::error_code ec_;
bool isAdmin_ = false;
std::string clientIp_;
public:
std::string const clientIp;
bool upgraded = false;
/**
@@ -57,7 +57,7 @@ public:
* @param ip The IP address of the connected peer
*/
ConnectionBase(util::TagDecoratorFactory const& tagFactory, std::string ip)
: Taggable(tagFactory), clientIp(std::move(ip))
: Taggable(tagFactory), clientIp_(std::move(ip))
{
}
@@ -119,5 +119,16 @@ public:
{
return isAdmin_;
}
/**
* @brief Get the IP address of the client.
*
* @return The IP address of the client.
*/
[[nodiscard]] std::string const&
clientIp() const
{
return clientIp_;
}
};
} // namespace web

View File

@@ -34,6 +34,7 @@
#include <memory>
#include <optional>
#include <string>
#include <utility>
namespace web::ng {
@@ -70,6 +71,17 @@ public:
std::string const&
ip() const;
/**
* @brief Set the ip of the client.
*
* @param newIp The new ip to set.
*/
void
setIp(std::string newIp)
{
ip_ = std::move(newIp);
}
/**
* @brief Get whether the client is an admin.
*

View File

@@ -25,6 +25,7 @@
#include "util/config/ConfigDefinition.hpp"
#include "util/config/ObjectView.hpp"
#include "util/log/Logger.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/MessageHandler.hpp"
#include "web/ng/ProcessingPolicy.hpp"
@@ -205,16 +206,16 @@ Server::Server(
ProcessingPolicy processingPolicy,
std::optional<size_t> parallelRequestLimit,
util::TagDecoratorFactory tagDecoratorFactory,
ProxyIpResolver proxyIpResolver,
std::optional<size_t> maxSubscriptionSendQueueSize,
OnConnectCheck onConnectCheck,
OnDisconnectHook onDisconnectHook
Hooks hooks
)
: ctx_{ctx}
, sslContext_{std::move(sslContext)}
, tagDecoratorFactory_{tagDecoratorFactory}
, connectionHandler_{processingPolicy, parallelRequestLimit, tagDecoratorFactory_, maxSubscriptionSendQueueSize, std::move(onDisconnectHook)}
, connectionHandler_{processingPolicy, parallelRequestLimit, tagDecoratorFactory_, maxSubscriptionSendQueueSize, std::move(proxyIpResolver), std::move(hooks.onDisconnectHook), std::move(hooks.onIpChangeHook)}
, endpoint_{std::move(endpoint)}
, onConnectCheck_{std::move(onConnectCheck)}
, onConnectCheck_{std::move(hooks.onConnectCheck)}
{
}
@@ -337,6 +338,7 @@ std::expected<Server, std::string>
makeServer(
util::config::ClioConfigDefinition const& config,
Server::OnConnectCheck onConnectCheck,
Server::OnIpChangeHook onIpChangeHook,
Server::OnDisconnectHook onDisconnectHook,
boost::asio::io_context& context
)
@@ -365,6 +367,8 @@ makeServer(
auto const maxSubscriptionSendQueueSize = serverConfig.get<size_t>("ws_max_sending_queue_size");
auto proxyIpResolver = ProxyIpResolver::fromConfig(config);
return std::expected<Server, std::string>{
std::in_place,
context,
@@ -373,9 +377,13 @@ makeServer(
processingPolicy,
parallelRequestLimit,
util::TagDecoratorFactory(config),
std::move(proxyIpResolver),
maxSubscriptionSendQueueSize,
std::move(onConnectCheck),
std::move(onDisconnectHook)
Server::Hooks{
.onConnectCheck = std::move(onConnectCheck),
.onIpChangeHook = std::move(onIpChangeHook),
.onDisconnectHook = std::move(onDisconnectHook)
}
};
}

View File

@@ -22,6 +22,7 @@
#include "util/Taggable.hpp"
#include "util/config/ConfigDefinition.hpp"
#include "util/log/Logger.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/MessageHandler.hpp"
#include "web/ng/ProcessingPolicy.hpp"
@@ -62,11 +63,22 @@ public:
*/
using OnConnectCheck = std::function<std::expected<void, Response>(Connection const&)>;
using OnIpChangeHook = impl::ConnectionHandler::OnIpChangeHook;
/**
* @brief Hook called when any connection disconnects
*/
using OnDisconnectHook = impl::ConnectionHandler::OnDisconnectHook;
/**
* @brief A struct that holds all the hooks for the server.
*/
struct Hooks {
OnConnectCheck onConnectCheck;
OnIpChangeHook onIpChangeHook;
OnDisconnectHook onDisconnectHook;
};
private:
util::Logger log_{"WebServer"};
util::Logger perfLog_{"Performance"};
@@ -94,9 +106,9 @@ public:
* @param parallelRequestLimit The limit of requests for one connection that can be processed in parallel. Only used
* if processingPolicy is parallel.
* @param tagDecoratorFactory The tag decorator factory.
* @param proxyIpResolver The client ip resolver if a request was forwarded by a proxy
* @param maxSubscriptionSendQueueSize The maximum size of the subscription send queue.
* @param onConnectCheck The check to perform on each connection.
* @param onDisconnectHook The hook to call on each disconnection.
* @param hooks The server hooks
*/
Server(
boost::asio::io_context& ctx,
@@ -105,9 +117,9 @@ public:
ProcessingPolicy processingPolicy,
std::optional<size_t> parallelRequestLimit,
util::TagDecoratorFactory tagDecoratorFactory,
ProxyIpResolver proxyIpResolver,
std::optional<size_t> maxSubscriptionSendQueueSize,
OnConnectCheck onConnectCheck,
OnDisconnectHook onDisconnectHook
Hooks hooks
);
/**
@@ -185,6 +197,7 @@ std::expected<Server, std::string>
makeServer(
util::config::ClioConfigDefinition const& config,
Server::OnConnectCheck onConnectCheck,
Server::OnIpChangeHook onIpChangeHook,
Server::OnDisconnectHook onDisconnectHook,
boost::asio::io_context& context
);

View File

@@ -23,6 +23,7 @@
#include "util/CoroutineGroup.hpp"
#include "util/Taggable.hpp"
#include "util/log/Logger.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SubscriptionContextInterface.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/Error.hpp"
@@ -90,13 +91,17 @@ ConnectionHandler::ConnectionHandler(
std::optional<size_t> maxParallelRequests,
util::TagDecoratorFactory& tagFactory,
std::optional<size_t> maxSubscriptionSendQueueSize,
OnDisconnectHook onDisconnectHook
ProxyIpResolver proxyIpResolver,
OnDisconnectHook onDisconnectHook,
OnIpChangeHook onIpChangeHook
)
: processingPolicy_{processingPolicy}
, maxParallelRequests_{maxParallelRequests}
, tagFactory_{tagFactory}
, maxSubscriptionSendQueueSize_{maxSubscriptionSendQueueSize}
, proxyIpResolver_(std::move(proxyIpResolver))
, onDisconnectHook_{std::move(onDisconnectHook)}
, onIpChangeHook_(std::move(onIpChangeHook))
{
}
@@ -277,9 +282,9 @@ ConnectionHandler::sequentRequestResponseLoop(
if (not expectedRequest)
return handleError(expectedRequest.error(), connection);
LOG(log_.info()) << connection.tag() << "Received request from ip = " << connection.ip();
resolveClientIp(connection, *expectedRequest);
auto maybeReturnValue = processRequest(connection, subscriptionContext, expectedRequest.value(), yield);
auto maybeReturnValue = processRequest(connection, subscriptionContext, *expectedRequest, yield);
if (maybeReturnValue.has_value())
return maybeReturnValue.value();
}
@@ -308,6 +313,8 @@ ConnectionHandler::parallelRequestResponseLoop(
break;
}
resolveClientIp(connection, *expectedRequest);
if (not tasksGroup.isFull()) {
bool const spawnSuccess = tasksGroup.spawn(
yield, // spawn on the same strand
@@ -385,4 +392,16 @@ ConnectionHandler::handleRequest(
}
}
void
ConnectionHandler::resolveClientIp(Connection& connection, Request const& request) const
{
if (auto resolvedClientIp = proxyIpResolver_.resolveClientIp(connection.ip(), request.httpHeaders());
resolvedClientIp != connection.ip()) {
LOG(log_.info()) << connection.tag() << "Detected a forwarded request from proxy. Proxy ip: " << connection.ip()
<< ". Resolved client ip: " << resolvedClientIp;
onIpChangeHook_(connection.ip(), resolvedClientIp);
connection.setIp(std::move(resolvedClientIp));
}
}
} // namespace web::ng::impl

View File

@@ -26,6 +26,7 @@
#include "util/prometheus/Gauge.hpp"
#include "util/prometheus/Label.hpp"
#include "util/prometheus/Prometheus.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SubscriptionContextInterface.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/Error.hpp"
@@ -52,6 +53,7 @@ namespace web::ng::impl {
class ConnectionHandler {
public:
using OnDisconnectHook = std::function<void(Connection const&)>;
using OnIpChangeHook = std::function<void(std::string const&, std::string const&)>;
using TargetToHandlerMap = std::unordered_map<std::string, MessageHandler, util::StringHash, std::equal_to<>>;
private:
@@ -64,7 +66,10 @@ private:
std::reference_wrapper<util::TagDecoratorFactory> tagFactory_;
std::optional<size_t> maxSubscriptionSendQueueSize_;
ProxyIpResolver proxyIpResolver_;
OnDisconnectHook onDisconnectHook_;
OnIpChangeHook onIpChangeHook_;
TargetToHandlerMap getHandlers_;
TargetToHandlerMap postHandlers_;
@@ -84,7 +89,9 @@ public:
std::optional<size_t> maxParallelRequests,
util::TagDecoratorFactory& tagFactory,
std::optional<size_t> maxSubscriptionSendQueueSize,
OnDisconnectHook onDisconnectHook
ProxyIpResolver proxyIpResolver,
OnDisconnectHook onDisconnectHook,
OnIpChangeHook onIpChangeHook
);
ConnectionHandler(ConnectionHandler&&) = delete;
@@ -159,6 +166,9 @@ private:
Request const& request,
boost::asio::yield_context yield
);
void
resolveClientIp(Connection& connection, Request const& request) const;
};
} // namespace web::ng::impl

View File

@@ -50,6 +50,7 @@
#include <string>
#include <string_view>
#include <utility>
#include <variant>
#include <vector>
namespace http = boost::beast::http;
@@ -81,8 +82,8 @@ syncRequest(
req.set(http::field::host, host);
req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING);
for (auto& header : additionalHeaders) {
req.set(header.name, header.value);
for (auto const& header : additionalHeaders) {
std::visit([&header, &req](auto const& name) { req.set(name, header.value); }, header.name);
}
req.target(target);
@@ -106,6 +107,10 @@ WebHeader::WebHeader(http::field name, std::string value) : name(name), value(st
{
}
WebHeader::WebHeader(std::string_view name, std::string value) : name(std::string{name}), value(std::move(value))
{
}
std::pair<boost::beast::http::status, std::string>
HttpSyncClient::post(
std::string const& host,

View File

@@ -35,12 +35,14 @@
#include <string>
#include <string_view>
#include <utility>
#include <variant>
#include <vector>
struct WebHeader {
WebHeader(boost::beast::http::field name, std::string value);
WebHeader(std::string_view name, std::string value);
boost::beast::http::field name;
std::variant<boost::beast::http::field, std::string> name;
std::string value;
};

View File

@@ -46,6 +46,7 @@
#include <string>
#include <string_view>
#include <utility>
#include <variant>
#include <vector>
namespace http = boost::beast::http;
@@ -69,7 +70,7 @@ WebSocketSyncClient::connect(std::string const& host, std::string const& port, s
[additionalHeaders = std::move(additionalHeaders)](boost::beast::websocket::request_type& req) {
req.set(http::field::user_agent, std::string(BOOST_BEAST_VERSION_STRING) + " websocket-client-coro");
for (auto const& header : additionalHeaders) {
req.set(header.name, header.value);
std::visit([&header, &req](auto const& name) { req.set(name, header.value); }, header.name);
}
}
)
@@ -164,7 +165,7 @@ WebSocketAsyncClient::connect(
boost::beast::websocket::stream_base::decorator(
[additionalHeaders = std::move(additionalHeaders)](boost::beast::websocket::request_type& req) {
for (auto const& header : additionalHeaders) {
req.set(header.name, header.value);
std::visit([&header, &req](auto const& name) { req.set(name, header.value); }, header.name);
}
}
)

View File

@@ -179,6 +179,7 @@ target_sources(
util/RetryTests.cpp
util/ScopeGuardTests.cpp
util/SignalsHandlerTests.cpp
util/ShasumTests.cpp
util/SpawnTests.cpp
util/StopHelperTests.cpp
util/TimeUtilsTests.cpp
@@ -201,6 +202,7 @@ target_sources(
web/ng/impl/HttpConnectionTests.cpp
web/ng/impl/ServerSslContextTests.cpp
web/ng/impl/WsConnectionTests.cpp
web/ProxyIpResolverTests.cpp
web/RPCServerHandlerTests.cpp
web/ServerTests.cpp
web/SubscriptionContextTests.cpp

View File

@@ -92,6 +92,21 @@ TEST_F(OnConnectCheckTests, RateLimited)
EXPECT_EQ(httpResponse.body(), "Too many requests");
}
struct IpChangeHookTests : WebHandlersTest {
IpChangeHook ipChangeHook{dosGuardMock};
};
TEST_F(IpChangeHookTests, CallsDecrementAndIncrement)
{
std::string const oldIp = "old ip";
std::string const newIp = "new ip";
EXPECT_CALL(dosGuardMock, decrement(oldIp));
EXPECT_CALL(dosGuardMock, increment(newIp));
ipChangeHook(oldIp, newIp);
}
struct DisconnectHookTests : WebHandlersTest {
DisconnectHook disconnectHook{dosGuardMock};
};

View File

@@ -0,0 +1,47 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "util/Shasum.hpp"
#include <gtest/gtest.h>
#include <xrpl/basics/base_uint.h>
using namespace util;
struct ShasumTest : testing::Test {
static constexpr auto kEMPTY_HASH = "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855";
static constexpr auto kHELLO_WORLD_HASH = "B94D27B9934D3E08A52E52D7DA7DABFAC484EFE37A5380EE9088F7ACE2EFCDE9";
};
TEST_F(ShasumTest, sha256sum)
{
ripple::uint256 expected;
ASSERT_TRUE(expected.parseHex(kEMPTY_HASH));
EXPECT_EQ(sha256sum(""), expected);
ASSERT_TRUE(expected.parseHex(kHELLO_WORLD_HASH));
EXPECT_EQ(sha256sum("hello world"), expected);
}
TEST_F(ShasumTest, sha256sumString)
{
EXPECT_EQ(sha256sumString(""), kEMPTY_HASH);
EXPECT_EQ(sha256sumString("hello world"), kHELLO_WORLD_HASH);
}

View File

@@ -0,0 +1,214 @@
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "util/NameGenerator.hpp"
#include "util/config/Array.hpp"
#include "util/config/ConfigDefinition.hpp"
#include "util/config/ConfigFileJson.hpp"
#include "util/config/ConfigValue.hpp"
#include "util/config/Types.hpp"
#include "web/ProxyIpResolver.hpp"
#include <boost/beast/http/field.hpp>
#include <boost/json/parse.hpp>
#include <fmt/format.h>
#include <gtest/gtest.h>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
namespace http = boost::beast::http;
using namespace web;
struct ProxyIpResolverTestParams {
std::string testName;
std::unordered_set<std::string> proxyIps;
std::unordered_set<std::string> proxyTokens;
std::vector<std::pair<std::string, std::string>> headers;
std::string connectionIp;
std::string expectedIp;
};
class ProxyIpResolverTest : public ::testing::TestWithParam<ProxyIpResolverTestParams> {};
TEST_F(ProxyIpResolverTest, FromConfig)
{
using namespace util::config;
ClioConfigDefinition config{{
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
}};
auto const proxyIp = "1.2.3.4";
auto const clientIp = "5.6.7.8";
auto const proxyToken = "some_proxy_token";
auto const configStr = fmt::format(
R"({{
"server": {{
"proxy": {{
"ips": ["{}"],
"tokens": ["{}"]
}}
}}
}})",
proxyIp,
proxyToken
);
auto const err = config.parse(ConfigFileJson{boost::json::parse(configStr).as_object()});
ASSERT_FALSE(err.has_value());
auto const proxyIpResolver = ProxyIpResolver::fromConfig(config);
ProxyIpResolver::HttpHeaders headers;
EXPECT_EQ(proxyIpResolver.resolveClientIp(clientIp, headers), clientIp);
EXPECT_EQ(proxyIpResolver.resolveClientIp(proxyIp, headers), proxyIp);
headers.set(boost::beast::http::field::forwarded, fmt::format("for={}", clientIp));
EXPECT_EQ(proxyIpResolver.resolveClientIp(clientIp, headers), clientIp);
EXPECT_EQ(proxyIpResolver.resolveClientIp(proxyIp, headers), clientIp);
headers.set(ProxyIpResolver::kPROXY_TOKEN_HEADER, proxyToken);
EXPECT_EQ(proxyIpResolver.resolveClientIp(clientIp, headers), clientIp);
EXPECT_EQ(proxyIpResolver.resolveClientIp(proxyIp, headers), clientIp);
EXPECT_EQ(proxyIpResolver.resolveClientIp("127.0.0.1", headers), clientIp);
}
TEST_P(ProxyIpResolverTest, ResolveClientIp)
{
auto const& params = GetParam();
ProxyIpResolver resolver(params.proxyIps, params.proxyTokens);
ProxyIpResolver::HttpHeaders headers;
for (auto const& [key, value] : params.headers) {
headers.set(key, value);
}
EXPECT_EQ(resolver.resolveClientIp(params.connectionIp, headers), params.expectedIp);
}
INSTANTIATE_TEST_SUITE_P(
ProxyIpResolverTests,
ProxyIpResolverTest,
::testing::Values(
ProxyIpResolverTestParams{
.testName = "NoProxy",
.proxyIps = {},
.proxyTokens = {},
.headers = {},
.connectionIp = "1.2.3.4",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "TrustedProxyIpWithForwardedHeader",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "for=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "TrustedProxyIpWithoutForwardedHeader",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
},
ProxyIpResolverTestParams{
.testName = "UntrustedProxyIpWithForwardedHeader",
.proxyIps = {},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "for=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
},
ProxyIpResolverTestParams{
.testName = "TrustedProxyTokenWithForwardedHeader",
.proxyIps = {},
.proxyTokens = {"test_token"},
.headers =
{{std::string(ProxyIpResolver::kPROXY_TOKEN_HEADER), "test_token"},
{std::string(http::to_string(http::field::forwarded)), "for=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "TrustedProxyTokenWithoutForwardedHeader",
.proxyIps = {},
.proxyTokens = {"test_token"},
.headers = {{std::string(ProxyIpResolver::kPROXY_TOKEN_HEADER), "test_token"}},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
},
ProxyIpResolverTestParams{
.testName = "UntrustedProxyTokenWithForwardedHeader",
.proxyIps = {},
.proxyTokens = {},
.headers =
{{std::string(ProxyIpResolver::kPROXY_TOKEN_HEADER), "test_token"},
{std::string(http::to_string(http::field::forwarded)), "for=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
},
ProxyIpResolverTestParams{
.testName = "ForwardedHeaderWithAdditionalFields",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers =
{{std::string(http::to_string(http::field::forwarded)),
"by=203.0.113.43; for=1.2.3.4; host=example.com; proto=https"}},
.connectionIp = "5.6.7.8",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "ForwardedHeaderWithDifferentCase",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "For=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "ForwardedHeaderWithoutFor",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "by=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
},
ProxyIpResolverTestParams{
.testName = "ForwardedHeaderWithIpInQuotes",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "for=\"1.2.3.4\""}},
.connectionIp = "5.6.7.8",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "ForwardedHeaderIsIncorrect",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "for=\";some_other_text"}},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
}
),
tests::util::kNAME_GENERATOR
);

View File

@@ -132,8 +132,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPDefaultPath)
]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -153,7 +153,7 @@ TEST_F(WebRPCServerHandlerTest, HTTPRejectedByDosguard)
"params": [{}]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(false));
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(false));
(*handler)(kREQUEST, session);
EXPECT_EQ(session->slowDownCallsCounter, 1);
@@ -166,8 +166,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPRejectedByDosguardAfterParsing)
"params": [{}]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(false));
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(false));
(*handler)(kREQUEST, session);
EXPECT_EQ(session->slowDownCallsCounter, 1);
@@ -198,8 +198,8 @@ TEST_F(WebRPCServerHandlerTest, WsNormalPath)
}
]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -221,7 +221,7 @@ TEST_F(WebRPCServerHandlerTest, WsRejectedByDosguard)
"api_version": 2
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(false));
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(false));
(*handler)(kREQUEST, session);
EXPECT_EQ(session->slowDownCallsCounter, 1);
@@ -236,8 +236,8 @@ TEST_F(WebRPCServerHandlerTest, WsRejectedByDosguardAfterParsing)
"api_version": 2
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(false));
(*handler)(kREQUEST, session);
@@ -274,8 +274,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPForwardedPath)
]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -323,8 +323,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPForwardedErrorPath)
]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -370,8 +370,8 @@ TEST_F(WebRPCServerHandlerTest, WsForwardedPath)
]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -420,8 +420,8 @@ TEST_F(WebRPCServerHandlerTest, WsForwardedErrorPath)
]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -472,8 +472,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPErrorPath)
]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST_JSON).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -521,8 +521,8 @@ TEST_F(WebRPCServerHandlerTest, WsErrorPath)
"api_version": 2
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST_JSON).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -557,8 +557,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPNotReady)
}
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyNotReady).Times(1);
@@ -589,8 +589,8 @@ TEST_F(WebRPCServerHandlerTest, WsNotReady)
}
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyNotReady).Times(1);
@@ -619,8 +619,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPBadSyntaxWhenRequestSubscribe)
}
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -636,8 +636,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPMissingCommand)
static constexpr auto kRESPONSE = "Null method";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -654,8 +654,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPCommandNotString)
static constexpr auto kRESPONSE = "method is not string";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -672,8 +672,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPCommandIsEmpty)
static constexpr auto kRESPONSE = "method is empty";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -705,8 +705,8 @@ TEST_F(WebRPCServerHandlerTest, WsMissingCommand)
}
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -726,8 +726,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPParamsUnparsableNotArray)
"params": "wrong"
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -747,8 +747,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPParamsUnparsableArrayWithDigit)
"params": [1]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -780,8 +780,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPInternalError)
"params": [{}]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST_JSON).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyInternalError).Times(1);
@@ -815,8 +815,8 @@ TEST_F(WebRPCServerHandlerTest, WsInternalError)
"id": "123"
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST_JSON).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyInternalError).Times(1);
@@ -852,8 +852,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPOutDated)
]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -895,8 +895,8 @@ TEST_F(WebRPCServerHandlerTest, WsOutdated)
]
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -931,8 +931,8 @@ TEST_F(WebRPCServerHandlerTest, WsTooBusy)
"type": "response"
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*localRpcEngine, notifyTooBusy).Times(1);
@@ -962,8 +962,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPTooBusy)
"type": "response"
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*localRpcEngine, notifyTooBusy).Times(1);
@@ -978,7 +978,7 @@ TEST_F(WebRPCServerHandlerTest, HTTPRequestNotJson)
static constexpr auto kREQUEST = "not json";
static constexpr auto kRESPONSE_PREFIX = "Unable to parse JSON from the request";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -1000,7 +1000,7 @@ TEST_F(WebRPCServerHandlerTest, WsRequestNotJson)
"type": "response"
})JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -1058,8 +1058,8 @@ TEST_P(WebRPCServerHandlerInvalidAPIVersionParamTest, HTTPInvalidAPIVersion)
backend_->setRange(kMIN_SEQ, kMAX_SEQ);
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(request).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(request).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -1082,8 +1082,8 @@ TEST_P(WebRPCServerHandlerInvalidAPIVersionParamTest, WSInvalidAPIVersion)
backend_->setRange(kMIN_SEQ, kMAX_SEQ);
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(request).as_object()))
EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(request).as_object()))
.WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);

View File

@@ -124,6 +124,8 @@ getParseServerConfig(boost::json::value val)
{"server.port", ConfigValue{ConfigType::Integer}},
{"server.admin_password", ConfigValue{ConfigType::String}.optional()},
{"server.local_admin", ConfigValue{ConfigType::Boolean}.optional()},
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
{"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)},
{"log.tag_style", ConfigValue{ConfigType::String}.defaultValue("uint")},
{"dos_guard.max_fetches", ConfigValue{ConfigType::Integer}},
@@ -519,6 +521,8 @@ getParseAdminServerConfig(boost::json::value val)
{"server.admin_password", ConfigValue{ConfigType::String}.optional()},
{"server.local_admin", ConfigValue{ConfigType::Boolean}.optional()},
{"server.processing_policy", ConfigValue{ConfigType::String}.defaultValue("parallel")},
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
{"server.parallel_requests_limit", ConfigValue{ConfigType::Integer}.optional()},
{"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)},
{"ssl_cert_file", ConfigValue{ConfigType::String}.optional()},

View File

@@ -26,11 +26,13 @@
#include "util/Taggable.hpp"
#include "util/TestHttpClient.hpp"
#include "util/TestWebSocketClient.hpp"
#include "util/config/Array.hpp"
#include "util/config/ConfigConstraints.hpp"
#include "util/config/ConfigDefinition.hpp"
#include "util/config/ConfigFileJson.hpp"
#include "util/config/ConfigValue.hpp"
#include "util/config/Types.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SubscriptionContextInterface.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/ProcessingPolicy.hpp"
@@ -58,6 +60,7 @@
#include <optional>
#include <ranges>
#include <string>
#include <unordered_set>
using namespace web::ng;
using namespace util::config;
@@ -85,6 +88,8 @@ TEST_P(MakeServerTest, Make)
{"server.ip", ConfigValue{ConfigType::String}.optional()},
{"server.port", ConfigValue{ConfigType::Integer}.optional()},
{"server.processing_policy", ConfigValue{ConfigType::String}.defaultValue("parallel")},
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
{"server.parallel_requests_limit", ConfigValue{ConfigType::Integer}.optional()},
{"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)},
{"log.tag_style", ConfigValue{ConfigType::String}.defaultValue("uint")},
@@ -95,8 +100,13 @@ TEST_P(MakeServerTest, Make)
auto const errors = config.parse(json);
ASSERT_TRUE(!errors.has_value());
auto const expectedServer =
makeServer(config, [](auto&&) -> std::expected<void, Response> { return {}; }, [](auto&&) {}, ioContext_);
auto const expectedServer = makeServer(
config,
[](auto&&) -> std::expected<void, Response> { return {}; },
[](auto&&, auto&&) {},
[](auto&&) {},
ioContext_
);
EXPECT_EQ(expectedServer.has_value(), GetParam().expectSuccess);
}
@@ -173,6 +183,8 @@ protected:
{"server.admin_password", ConfigValue{ConfigType::String}.optional()},
{"server.local_admin", ConfigValue{ConfigType::Boolean}.optional()},
{"server.parallel_requests_limit", ConfigValue{ConfigType::Integer}.optional()},
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
{"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)},
{"log.tag_style", ConfigValue{ConfigType::String}.defaultValue("uint")},
{"ssl_key_file", ConfigValue{ConfigType::String}.optional()},
@@ -180,7 +192,8 @@ protected:
};
Server::OnConnectCheck emptyOnConnectCheck_ = [](auto&&) -> std::expected<void, Response> { return {}; };
std::expected<Server, std::string> server_ = makeServer(config_, emptyOnConnectCheck_, [](auto&&) {}, ctx_);
std::expected<Server, std::string> server_ =
makeServer(config_, emptyOnConnectCheck_, [](auto&&, auto&&) {}, [](auto&&) {}, ctx_);
std::string requestMessage_ = "some request";
std::string const headerName_ = "Some-header";
@@ -210,9 +223,13 @@ TEST_F(ServerTest, BadEndpoint)
ProcessingPolicy::Sequential,
std::nullopt,
tagDecoratorFactory,
web::ProxyIpResolver{{}, {}},
std::nullopt,
emptyOnConnectCheck_,
[](auto&&) {}
Server::Hooks{
.onConnectCheck = emptyOnConnectCheck_,
.onIpChangeHook = [](auto&&, auto&&) {},
.onDisconnectHook = [](auto&&) {}
}
};
auto maybeError = server.run();
@@ -274,9 +291,13 @@ TEST_F(ServerHttpTest, OnConnectCheck)
ProcessingPolicy::Sequential,
std::nullopt,
tagDecoratorFactory,
web::ProxyIpResolver{{}, {}},
std::nullopt,
onConnectCheck.AsStdFunction(),
[](auto&&) {}
Server::Hooks{
.onConnectCheck = onConnectCheck.AsStdFunction(),
.onIpChangeHook = [](auto&&, auto&&) {},
.onDisconnectHook = [](auto&&) {}
}
};
HttpAsyncClient client{ctx_};
@@ -334,9 +355,13 @@ TEST_F(ServerHttpTest, OnConnectCheckFailed)
ProcessingPolicy::Sequential,
std::nullopt,
tagDecoratorFactory,
web::ProxyIpResolver{{}, {}},
std::nullopt,
onConnectCheck.AsStdFunction(),
[](auto&&) {}
Server::Hooks{
.onConnectCheck = onConnectCheck.AsStdFunction(),
.onIpChangeHook = [](auto&&, auto&&) {},
.onDisconnectHook = [](auto&&) {}
}
};
HttpAsyncClient client{ctx_};
@@ -393,9 +418,13 @@ TEST_F(ServerHttpTest, OnDisconnectHook)
ProcessingPolicy::Sequential,
std::nullopt,
tagDecoratorFactory,
web::ProxyIpResolver{{}, {}},
std::nullopt,
emptyOnConnectCheck_,
onDisconnectHookMock.AsStdFunction()
Server::Hooks{
.onConnectCheck = emptyOnConnectCheck_,
.onIpChangeHook = [](auto&&, auto&&) {},
.onDisconnectHook = onDisconnectHookMock.AsStdFunction()
}
};
HttpAsyncClient client{ctx_};

View File

@@ -25,6 +25,7 @@
#include "util/config/ConfigDefinition.hpp"
#include "util/config/ConfigValue.hpp"
#include "util/config/Types.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SubscriptionContextInterface.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/Error.hpp"
@@ -40,11 +41,13 @@
#include <boost/asio/steady_timer.hpp>
#include <boost/beast/core/flat_buffer.hpp>
#include <boost/beast/http/error.hpp>
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/status.hpp>
#include <boost/beast/http/string_body.hpp>
#include <boost/beast/http/verb.hpp>
#include <boost/beast/websocket/error.hpp>
#include <fmt/format.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -60,7 +63,6 @@ using namespace web::ng::impl;
using namespace web::ng;
using namespace util;
using testing::Return;
namespace beast = boost::beast;
namespace http = boost::beast::http;
namespace websocket = boost::beast::websocket;
@@ -69,7 +71,15 @@ struct ConnectionHandlerTest : prometheus::WithPrometheus, SyncAsioContextTest {
: tagFactory{util::config::ClioConfigDefinition{
{"log.tag_style", config::ConfigValue{config::ConfigType::String}.defaultValue("uint")}
}}
, connectionHandler{policy, maxParallelConnections, tagFactory, std::nullopt, onDisconnectMock.AsStdFunction()}
, connectionHandler{
policy,
maxParallelConnections,
tagFactory,
std::nullopt,
proxyIpResolver,
onDisconnectMock.AsStdFunction(),
onIpChangeMock.AsStdFunction()
}
{
}
@@ -98,6 +108,12 @@ struct ConnectionHandlerTest : prometheus::WithPrometheus, SyncAsioContextTest {
return Request{std::forward<Args>(args)...};
}
std::string const clientIp = "1.2.3.4";
std::string const proxyIp = "5.6.7.8";
std::string const proxyToken = "some_proxy_token";
web::ProxyIpResolver proxyIpResolver{{proxyIp}, {proxyToken}};
testing::StrictMock<testing::MockFunction<void(std::string const&, std::string const&)>> onIpChangeMock;
testing::StrictMock<testing::MockFunction<void(Connection const&)>> onDisconnectMock;
util::TagDecoratorFactory tagFactory;
ConnectionHandler connectionHandler;
@@ -106,9 +122,9 @@ struct ConnectionHandlerTest : prometheus::WithPrometheus, SyncAsioContextTest {
{"log.tag_style", config::ConfigValue{config::ConfigType::String}.defaultValue("uint")}
}};
StrictMockHttpConnectionPtr mockHttpConnection =
std::make_unique<StrictMockHttpConnection>("1.2.3.4", beast::flat_buffer{}, tagDecoratorFactory);
std::make_unique<StrictMockHttpConnection>(clientIp, boost::beast::flat_buffer{}, tagDecoratorFactory);
StrictMockWsConnectionPtr mockWsConnection =
std::make_unique<StrictMockWsConnection>("1.2.3.4", beast::flat_buffer{}, tagDecoratorFactory);
std::make_unique<StrictMockWsConnection>(clientIp, boost::beast::flat_buffer{}, tagDecoratorFactory);
Request::HttpHeaders headers;
};
@@ -452,6 +468,50 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_SendError)
});
}
TEST_F(ConnectionHandlerSequentialProcessingTest, OnIpChangeHookCalledWhenSentFromProxy)
{
std::string const target = "/some/target";
testing::StrictMock<testing::MockFunction<
Response(Request const&, ConnectionMetadata const&, web::SubscriptionContextPtr, boost::asio::yield_context)>>
getHandlerMock;
std::string const requestMessage = "some message";
std::string const responseMessage = "some response";
connectionHandler.onGet(target, getHandlerMock.AsStdFunction());
StrictMockHttpConnectionPtr mockHttpConnectionFromProxy =
std::make_unique<StrictMockHttpConnection>(proxyIp, boost::beast::flat_buffer{}, tagDecoratorFactory);
auto request = http::request<http::string_body>{http::verb::get, target, 11, requestMessage};
request.set(http::field::forwarded, fmt::format("for={}", clientIp));
EXPECT_CALL(*mockHttpConnectionFromProxy, wasUpgraded).WillOnce(Return(false));
EXPECT_CALL(*mockHttpConnectionFromProxy, receive).WillOnce(Return(makeRequest(request)));
EXPECT_CALL(onIpChangeMock, Call(proxyIp, clientIp));
EXPECT_CALL(getHandlerMock, Call).WillOnce([&](Request const& request, auto&&, auto&&, auto&&) {
EXPECT_EQ(request.message(), requestMessage);
return Response(http::status::ok, responseMessage, request);
});
EXPECT_CALL(*mockHttpConnectionFromProxy, send).WillOnce([&responseMessage](Response response, auto&&) {
EXPECT_EQ(response.message(), responseMessage);
return makeError(http::error::end_of_stream).error();
});
EXPECT_CALL(onDisconnectMock, Call)
.WillOnce([this, connectionPtr = mockHttpConnectionFromProxy.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
EXPECT_EQ(c.ip(), clientIp);
});
runSpawn([this, c = std::move(mockHttpConnectionFromProxy)](boost::asio::yield_context yield) mutable {
connectionHandler.processConnection(std::move(c), yield);
});
}
TEST_F(ConnectionHandlerSequentialProcessingTest, Stop)
{
testing::StrictMock<testing::MockFunction<
@@ -613,6 +673,48 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send)
});
}
TEST_F(ConnectionHandlerParallelProcessingTest, OnIpChangeHookCalledWhenSentFromProxy)
{
testing::StrictMock<testing::MockFunction<
Response(Request const&, ConnectionMetadata const&, web::SubscriptionContextPtr, boost::asio::yield_context)>>
wsHandlerMock;
connectionHandler.onWs(wsHandlerMock.AsStdFunction());
StrictMockWsConnectionPtr mockWsConnectionFromProxy =
std::make_unique<StrictMockWsConnection>(proxyIp, boost::beast::flat_buffer{}, tagDecoratorFactory);
headers.set(http::field::forwarded, fmt::format("for={}", clientIp));
std::string const requestMessage = "some message";
std::string const responseMessage = "some response";
EXPECT_CALL(*mockWsConnectionFromProxy, wasUpgraded).WillOnce(Return(true));
EXPECT_CALL(*mockWsConnectionFromProxy, receive)
.WillOnce(Return(makeRequest(requestMessage, headers)))
.WillOnce(Return(makeError(websocket::error::closed)));
EXPECT_CALL(onIpChangeMock, Call(proxyIp, clientIp));
EXPECT_CALL(wsHandlerMock, Call).WillOnce([&](Request const& request, auto&&, auto&&, auto&&) {
EXPECT_EQ(request.message(), requestMessage);
return Response(http::status::ok, responseMessage, request);
});
EXPECT_CALL(*mockWsConnectionFromProxy, send).WillOnce([&responseMessage](Response response, auto&&) {
EXPECT_EQ(response.message(), responseMessage);
return std::nullopt;
});
EXPECT_CALL(onDisconnectMock, Call)
.WillOnce([this, connectionPtr = mockWsConnectionFromProxy.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
EXPECT_EQ(c.ip(), clientIp);
});
runSpawn([this, c = std::move(mockWsConnectionFromProxy)](boost::asio::yield_context yield) mutable {
connectionHandler.processConnection(std::move(c), yield);
});
}
TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop)
{
testing::StrictMock<testing::MockFunction<

View File

@@ -38,6 +38,7 @@
#include <boost/beast/http/status.hpp>
#include <boost/beast/http/string_body.hpp>
#include <boost/beast/http/verb.hpp>
#include <fmt/format.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>

View File

@@ -39,7 +39,9 @@
#include <boost/asio/spawn.hpp>
#include <boost/asio/steady_timer.hpp>
#include <boost/beast/core/flat_buffer.hpp>
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/status.hpp>
#include <fmt/format.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
@@ -65,7 +67,10 @@ struct WebWsConnectionTests : SyncAsioContextTest {
auto ip = expectedSocket->remote_endpoint().address().to_string();
PlainHttpConnection httpConnection{
std::move(expectedSocket).value(), std::move(ip), boost::beast::flat_buffer{}, tagDecoratorFactory_
std::move(expectedSocket).value(),
std::move(ip),
boost::beast::flat_buffer{},
tagDecoratorFactory_,
};
auto expectedTrue = httpConnection.isUpgradeRequested(yield);