feat: Proxy support (#2490)

Add client IP resolving support in case when there is a proxy in front
of Clio.
This commit is contained in:
Sergey Kuznetsov
2025-09-03 15:22:47 +01:00
committed by GitHub
parent 0a2930d861
commit 3a667f558c
39 changed files with 1042 additions and 125 deletions

View File

@@ -329,6 +329,22 @@ This document provides a list of all available Clio configuration properties in
- **Constraints**: The minimum value is `1`. The maximum value is `4294967295`. - **Constraints**: The minimum value is `1`. The maximum value is `4294967295`.
- **Description**: Maximum queue size for sending subscription data to clients. This queue buffers data when a client is slow to receive it, ensuring delivery once the client is ready. - **Description**: Maximum queue size for sending subscription data to clients. This queue buffers data when a client is slow to receive it, ensuring delivery once the client is ready.
### server.proxy.ips.[]
- **Required**: True
- **Type**: string
- **Default value**: None
- **Constraints**: None
- **Description**: List of proxy ip addresses. When Clio receives a request from proxy it will use `Forwarded` value (if any) as client ip. When this option is used together with `server.proxy.tokens` Clio will identify proxy by ip or by token.
### server.proxy.tokens.[]
- **Required**: True
- **Type**: string
- **Default value**: None
- **Constraints**: None
- **Description**: List of tokens in identifying request as a request from proxy. Token should be provided in `X-Proxy-Token` header, e.g. `X-Proxy-Token: <very_secret_token>'. When Clio receives a request from proxy it will use 'Forwarded` value (if any) to get client ip. When this option is used together with 'server.proxy.ips' Clio will identify proxy by ip or by token.
### prometheus.enabled ### prometheus.enabled
- **Required**: True - **Required**: True

View File

@@ -76,7 +76,11 @@
"parallel_requests_limit": 10, // Optional parameter, used only if "processing_strategy" is "parallel". It limits the number of requests for one client connection processed in parallel. Infinite if not specified. "parallel_requests_limit": 10, // Optional parameter, used only if "processing_strategy" is "parallel". It limits the number of requests for one client connection processed in parallel. Infinite if not specified.
// Max number of responses to queue up before sent successfully. If a client's waiting queue is too long, the server will close the connection. // Max number of responses to queue up before sent successfully. If a client's waiting queue is too long, the server will close the connection.
"ws_max_sending_queue_size": 1500, "ws_max_sending_queue_size": 1500,
"__ng_web_server": false // Use ng web server. This is a temporary setting which will be deleted after switching to ng web server "__ng_web_server": false, // Use ng web server. This is a temporary setting which will be deleted after switching to ng web server
"proxy": {
"ips": [],
"tokens": []
}
}, },
// Time in seconds for graceful shutdown. Defaults to 10 seconds. Not fully implemented yet. // Time in seconds for graceful shutdown. Defaults to 10 seconds. Not fully implemented yet.
"graceful_period": 10.0, "graceful_period": 10.0,

View File

@@ -178,7 +178,9 @@ ClioApplication::run(bool const useNgWebServer)
} }
auto const adminVerifier = std::move(expectedAdminVerifier).value(); auto const adminVerifier = std::move(expectedAdminVerifier).value();
auto httpServer = web::ng::makeServer(config_, OnConnectCheck{dosGuard}, DisconnectHook{dosGuard}, ioc); auto httpServer = web::ng::makeServer(
config_, OnConnectCheck{dosGuard}, IpChangeHook{dosGuard}, DisconnectHook{dosGuard}, ioc
);
if (not httpServer.has_value()) { if (not httpServer.has_value()) {
LOG(util::LogService::error()) << "Error creating web server: " << httpServer.error(); LOG(util::LogService::error()) << "Error creating web server: " << httpServer.error();

View File

@@ -33,6 +33,7 @@
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <string>
#include <utility> #include <utility>
namespace app { namespace app {
@@ -54,6 +55,17 @@ OnConnectCheck::operator()(web::ng::Connection const& connection)
return {}; return {};
} }
IpChangeHook::IpChangeHook(web::dosguard::DOSGuardInterface& dosguard) : dosguard_(dosguard)
{
}
void
IpChangeHook::operator()(std::string const& oldIp, std::string const& newIp)
{
dosguard_.get().decrement(oldIp);
dosguard_.get().increment(newIp);
}
DisconnectHook::DisconnectHook(web::dosguard::DOSGuardInterface& dosguard) : dosguard_{dosguard} DisconnectHook::DisconnectHook(web::dosguard::DOSGuardInterface& dosguard) : dosguard_{dosguard}
{ {
} }

View File

@@ -36,6 +36,7 @@
#include <exception> #include <exception>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <string>
#include <utility> #include <utility>
namespace app { namespace app {
@@ -64,6 +65,31 @@ public:
operator()(web::ng::Connection const& connection); operator()(web::ng::Connection const& connection);
}; };
/**
* @brief A function object that is called when the IP of a connection changes (usually if proxy detected).
* This is used to update the DOS guard.
*/
class IpChangeHook {
std::reference_wrapper<web::dosguard::DOSGuardInterface> dosguard_;
public:
/**
* @brief Construct a new IpChangeHook object.
*
* @param dosguard The DOS guard to use.
*/
IpChangeHook(web::dosguard::DOSGuardInterface& dosguard);
/**
* @brief The call of the function object.
*
* @param oldIp The old IP of the connection.
* @param newIp The new IP of the connection.
*/
void
operator()(std::string const& oldIp, std::string const& newIp);
};
/** /**
* @brief A function object to be called when a connection is disconnected. * @brief A function object to be called when a connection is disconnected.
*/ */

View File

@@ -24,6 +24,7 @@ target_sources(
requests/WsConnection.cpp requests/WsConnection.cpp
requests/impl/SslContext.cpp requests/impl/SslContext.cpp
ResponseExpirationCache.cpp ResponseExpirationCache.cpp
Shasum.cpp
SignalsHandler.cpp SignalsHandler.cpp
StopHelper.cpp StopHelper.cpp
StringHash.cpp StringHash.cpp

48
src/util/Shasum.cpp Normal file
View File

@@ -0,0 +1,48 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "util/Shasum.hpp"
#include <xrpl/basics/base_uint.h>
#include <xrpl/protocol/digest.h>
#include <cstring>
#include <string>
#include <string_view>
namespace util {
ripple::uint256
sha256sum(std::string_view s)
{
ripple::sha256_hasher hasher;
hasher(s.data(), s.size());
auto const hashData = static_cast<ripple::sha256_hasher::result_type>(hasher);
ripple::uint256 sha256;
std::memcpy(sha256.data(), hashData.data(), hashData.size());
return sha256;
}
std::string
sha256sumString(std::string_view s)
{
return ripple::to_string(sha256sum(s));
}
} // namespace util

46
src/util/Shasum.hpp Normal file
View File

@@ -0,0 +1,46 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include <xrpl/basics/base_uint.h>
#include <string>
#include <string_view>
namespace util {
/**
* @brief Calculates the SHA256 sum of a string.
*
* @param s The string to hash.
* @return The SHA256 sum as a ripple::uint256.
*/
ripple::uint256
sha256sum(std::string_view s);
/**
* @brief Calculates the SHA256 sum of a string and returns it as a hex string.
*
* @param s The string to hash.
* @return The SHA256 sum as a hex string.
*/
std::string
sha256sumString(std::string_view s);
} // namespace util

View File

@@ -337,6 +337,8 @@ getClioConfig()
{"server.ws_max_sending_queue_size", {"server.ws_max_sending_queue_size",
ConfigValue{ConfigType::Integer}.defaultValue(1500).withConstraint(gValidateUint32)}, ConfigValue{ConfigType::Integer}.defaultValue(1500).withConstraint(gValidateUint32)},
{"server.__ng_web_server", ConfigValue{ConfigType::Boolean}.defaultValue(false)}, {"server.__ng_web_server", ConfigValue{ConfigType::Boolean}.defaultValue(false)},
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
{"prometheus.enabled", ConfigValue{ConfigType::Boolean}.defaultValue(true)}, {"prometheus.enabled", ConfigValue{ConfigType::Boolean}.defaultValue(true)},
{"prometheus.compress_reply", ConfigValue{ConfigType::Boolean}.defaultValue(true)}, {"prometheus.compress_reply", ConfigValue{ConfigType::Boolean}.defaultValue(true)},

View File

@@ -236,6 +236,16 @@ This document provides a list of all available Clio configuration properties in
KV{.key = "server.ws_max_sending_queue_size", KV{.key = "server.ws_max_sending_queue_size",
.value = "Maximum queue size for sending subscription data to clients. This queue buffers data when a " .value = "Maximum queue size for sending subscription data to clients. This queue buffers data when a "
"client is slow to receive it, ensuring delivery once the client is ready."}, "client is slow to receive it, ensuring delivery once the client is ready."},
KV{.key = "server.proxy.ips.[]",
.value = "List of proxy ip addresses. When Clio receives a request from proxy it will use "
"`Forwarded` value (if any) as client ip. When this option is used together with "
"`server.proxy.tokens` Clio will identify proxy by ip or by token."},
KV{.key = "server.proxy.tokens.[]",
.value = "List of tokens in identifying request as a request from proxy. Token should be provided in "
"`X-Proxy-Token` header, e.g. "
"`X-Proxy-Token: <very_secret_token>'. When Clio receives a request from proxy "
"it will use 'Forwarded` value (if any) to get client ip. When this option is used together with "
"'server.proxy.ips' Clio will identify proxy by ip or by token."},
KV{.key = "prometheus.enabled", .value = "Enables or disables Prometheus metrics."}, KV{.key = "prometheus.enabled", .value = "Enables or disables Prometheus metrics."},
KV{.key = "prometheus.compress_reply", .value = "Enables or disables compression of Prometheus responses."}, KV{.key = "prometheus.compress_reply", .value = "Enables or disables compression of Prometheus responses."},
KV{.key = "io_threads", KV{.key = "io_threads",

View File

@@ -20,6 +20,7 @@
#include "web/AdminVerificationStrategy.hpp" #include "web/AdminVerificationStrategy.hpp"
#include "util/JsonUtils.hpp" #include "util/JsonUtils.hpp"
#include "util/Shasum.hpp"
#include "util/config/ConfigDefinition.hpp" #include "util/config/ConfigDefinition.hpp"
#include <boost/beast/http/field.hpp> #include <boost/beast/http/field.hpp>
@@ -42,15 +43,8 @@ IPAdminVerificationStrategy::isAdmin(RequestHeader const&, std::string_view ip)
} }
PasswordAdminVerificationStrategy::PasswordAdminVerificationStrategy(std::string const& password) PasswordAdminVerificationStrategy::PasswordAdminVerificationStrategy(std::string const& password)
: passwordSha256_(util::toUpper(util::sha256sumString(password)))
{ {
ripple::sha256_hasher hasher;
hasher(password.data(), password.size());
auto const d = static_cast<ripple::sha256_hasher::result_type>(hasher);
ripple::uint256 sha256;
std::memcpy(sha256.data(), d.data(), d.size());
passwordSha256_ = ripple::to_string(sha256);
// make sure it's uppercase
passwordSha256_ = util::toUpper(std::move(passwordSha256_));
} }
bool bool

View File

@@ -15,6 +15,7 @@ target_sources(
ng/Response.cpp ng/Response.cpp
ng/Server.cpp ng/Server.cpp
ng/SubscriptionContext.cpp ng/SubscriptionContext.cpp
ProxyIpResolver.cpp
Resolver.cpp Resolver.cpp
SubscriptionContext.cpp SubscriptionContext.cpp
) )

View File

@@ -22,6 +22,7 @@
#include "util/Taggable.hpp" #include "util/Taggable.hpp"
#include "web/AdminVerificationStrategy.hpp" #include "web/AdminVerificationStrategy.hpp"
#include "web/PlainWsSession.hpp" #include "web/PlainWsSession.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/dosguard/DOSGuardInterface.hpp" #include "web/dosguard/DOSGuardInterface.hpp"
#include "web/impl/HttpBase.hpp" #include "web/impl/HttpBase.hpp"
#include "web/interface/Concepts.hpp" #include "web/interface/Concepts.hpp"
@@ -64,6 +65,7 @@ public:
* @param socket The socket. Ownership is transferred to HttpSession * @param socket The socket. Ownership is transferred to HttpSession
* @param ip Client's IP address * @param ip Client's IP address
* @param adminVerification The admin verification strategy to use * @param adminVerification The admin verification strategy to use
* @param proxyIpResolver The client ip resolver if a request was forwarded by a proxy
* @param tagFactory A factory that is used to generate tags to track requests and sessions * @param tagFactory A factory that is used to generate tags to track requests and sessions
* @param dosGuard The denial of service guard to use * @param dosGuard The denial of service guard to use
* @param handler The server handler to use * @param handler The server handler to use
@@ -74,6 +76,7 @@ public:
tcp::socket&& socket, tcp::socket&& socket,
std::string const& ip, std::string const& ip,
std::shared_ptr<AdminVerificationStrategy> const& adminVerification, std::shared_ptr<AdminVerificationStrategy> const& adminVerification,
std::shared_ptr<ProxyIpResolver> proxyIpResolver,
std::reference_wrapper<util::TagDecoratorFactory const> tagFactory, std::reference_wrapper<util::TagDecoratorFactory const> tagFactory,
std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard, std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard,
std::shared_ptr<HandlerType> const& handler, std::shared_ptr<HandlerType> const& handler,
@@ -84,6 +87,7 @@ public:
ip, ip,
tagFactory, tagFactory,
adminVerification, adminVerification,
std::move(proxyIpResolver),
dosGuard, dosGuard,
handler, handler,
std::move(buffer) std::move(buffer)
@@ -129,7 +133,7 @@ public:
{ {
std::make_shared<WsUpgrader<HandlerType>>( std::make_shared<WsUpgrader<HandlerType>>(
std::move(stream_), std::move(stream_),
this->clientIp, this->clientIp_,
tagFactory_, tagFactory_,
this->dosGuard_, this->dosGuard_,
this->handler_, this->handler_,

119
src/web/ProxyIpResolver.cpp Normal file
View File

@@ -0,0 +1,119 @@
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "web/ProxyIpResolver.hpp"
#include "util/JsonUtils.hpp"
#include "util/Shasum.hpp"
#include "util/config/ArrayView.hpp"
#include "util/config/ConfigDefinition.hpp"
#include "util/config/ValueView.hpp"
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/string_body.hpp>
#include <xrpl/basics/base_uint.h>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_set>
#include <utility>
namespace web {
ProxyIpResolver::ProxyIpResolver(std::unordered_set<std::string> proxyIps, std::unordered_set<std::string> proxyTokens)
: proxyIps_(std::move(proxyIps))
{
proxyTokens_.reserve(proxyTokens.size());
for (auto const& t : proxyTokens) {
proxyTokens_.insert(util::sha256sum(t));
}
}
ProxyIpResolver
ProxyIpResolver::fromConfig(util::config::ClioConfigDefinition const& config)
{
using util::config::ValueView;
std::unordered_set<std::string> ips;
auto const ipsFromConfig = config.getArray("server.proxy.ips");
for (auto it = ipsFromConfig.begin<ValueView>(); it != ipsFromConfig.end<ValueView>(); ++it) {
ips.insert((*it).asString());
}
std::unordered_set<std::string> tokens;
auto const tokensFromConfig = config.getArray("server.proxy.tokens");
for (auto it = tokensFromConfig.begin<ValueView>(); it != tokensFromConfig.end<ValueView>(); ++it) {
tokens.insert((*it).asString());
}
return ProxyIpResolver{std::move(ips), std::move(tokens)};
}
std::string
ProxyIpResolver::resolveClientIp(std::string const& connectionIp, HttpHeaders const& headers) const
{
if (proxyIps_.contains(connectionIp)) {
return extractClientIp(headers).value_or(connectionIp);
}
if (auto it = headers.find(kPROXY_TOKEN_HEADER); it != headers.end()) {
auto const tokenHash = util::sha256sum(it->value());
if (proxyTokens_.contains(tokenHash)) {
return extractClientIp(headers).value_or(connectionIp);
}
}
return connectionIp;
}
std::optional<std::string>
ProxyIpResolver::extractClientIp(HttpHeaders const& headers)
{
auto const it = headers.find(boost::beast::http::field::forwarded);
if (it == headers.end()) {
return std::nullopt;
}
// Forwarded header is case insensitive:
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Forwarded#using_the_forwarded_header
auto const headerValue = util::toLower(it->value());
static constexpr std::string_view kFOR_PREFIX = "for=";
auto const startPos = headerValue.find(kFOR_PREFIX);
if (startPos == std::string::npos) {
return std::nullopt;
}
auto value = it->value().substr(startPos + kFOR_PREFIX.size());
static constexpr char kDELIMITER = ';';
auto const endPos = value.find(kDELIMITER);
auto const ip = value.substr(0, endPos);
static constexpr auto kMIN_IP_LENGTH = 7; // minimum 3 dots + 4 digits
if (ip.size() < kMIN_IP_LENGTH) {
return std::nullopt;
}
if (ip.starts_with('"')) {
return ip.substr(1, ip.size() - 2);
}
return ip;
}
} // namespace web

101
src/web/ProxyIpResolver.hpp Normal file
View File

@@ -0,0 +1,101 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include "util/config/ConfigDefinition.hpp"
#include "util/config/ValueView.hpp"
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/string_body.hpp>
#include <xrpl/basics/base_uint.h>
#include <optional>
#include <string>
#include <string_view>
#include <unordered_set>
namespace web {
/**
* @brief Resolves the client's IP address, considering proxy servers.
*
* This class is designed to determine the original IP address of a client when the connection
* is forwarded through a proxy server. It uses a configurable list of trusted proxy IPs
* and proxy tokens to decide whether to trust the `Forwarded` HTTP header.
*/
class ProxyIpResolver {
std::unordered_set<std::string> proxyIps_;
// ripple::uint256 doesn't have hash implementation
std::unordered_set<ripple::uint256, ripple::uint256::hasher> proxyTokens_;
public:
/**
* @brief Constructs a ProxyIpResolver.
*
* @param proxyIps A set of trusted proxy IP addresses.
* @param proxyTokens A set of trusted proxy tokens. The tokens will be hashed with SHA-256.
*/
ProxyIpResolver(std::unordered_set<std::string> proxyIps, std::unordered_set<std::string> proxyTokens);
/**
* @brief Creates a ProxyIpResolver from a configuration.
*
* The configuration should contain `server.proxy.ips` and `server.proxy.tokens` arrays.
*
* @param config The Clio configuration.
* @return A new ProxyIpResolver instance.
*/
static ProxyIpResolver
fromConfig(util::config::ClioConfigDefinition const& config);
using HttpHeaders = boost::beast::http::request<boost::beast::http::string_body>::header_type;
static constexpr std::string_view kPROXY_TOKEN_HEADER = "X-Proxy-Token";
/**
* @brief Resolves the client's IP address from the connection IP and HTTP headers.
*
* If the connection IP is in the trusted proxy list, or if a valid proxy token is provided in the headers,
* this method will attempt to extract the client's IP from the `Forwarded` header.
* Otherwise, it returns the connection IP.
*
* @param connectionIp The IP address of the direct connection.
* @param headers The HTTP request headers.
* @return The resolved client IP address as a string.
*/
std::string
resolveClientIp(std::string const& connectionIp, HttpHeaders const& headers) const;
private:
/**
* @brief Extracts the client IP from the `Forwarded` HTTP header.
*
* The `Forwarded` header is expected to be in the format specified by RFC 7239.
* This function looks for the `for` parameter.
*
* @param headers The HTTP request headers.
* @return The client IP address as a string if found, otherwise std::nullopt.
*/
static std::optional<std::string>
extractClientIp(HttpHeaders const& headers);
};
} // namespace web

View File

@@ -107,7 +107,7 @@ public:
void void
operator()(std::string const& request, std::shared_ptr<web::ConnectionBase> const& connection) operator()(std::string const& request, std::shared_ptr<web::ConnectionBase> const& connection)
{ {
if (not dosguard_.get().isOk(connection->clientIp)) { if (not dosguard_.get().isOk(connection->clientIp())) {
connection->sendSlowDown(request); connection->sendSlowDown(request);
return; return;
} }
@@ -119,7 +119,7 @@ public:
if (not connection->upgraded and shouldReplaceParams(req)) if (not connection->upgraded and shouldReplaceParams(req))
req[JS(params)] = boost::json::array({boost::json::object{}}); req[JS(params)] = boost::json::array({boost::json::object{}});
if (not dosguard_.get().request(connection->clientIp, req)) { if (not dosguard_.get().request(connection->clientIp(), req)) {
connection->sendSlowDown(request); connection->sendSlowDown(request);
return; return;
} }
@@ -128,7 +128,7 @@ public:
[this, request = std::move(req), connection](boost::asio::yield_context yield) mutable { [this, request = std::move(req), connection](boost::asio::yield_context yield) mutable {
handleRequest(yield, std::move(request), connection); handleRequest(yield, std::move(request), connection);
}, },
connection->clientIp connection->clientIp()
)) { )) {
rpcEngine_->notifyTooBusy(); rpcEngine_->notifyTooBusy();
web::impl::ErrorHelper(connection).sendTooBusyError(); web::impl::ErrorHelper(connection).sendTooBusyError();
@@ -160,7 +160,7 @@ private:
{ {
LOG(log_.info()) << connection->tag() << (connection->upgraded ? "ws" : "http") LOG(log_.info()) << connection->tag() << (connection->upgraded ? "ws" : "http")
<< " received request from work queue: " << util::removeSecret(request) << " received request from work queue: " << util::removeSecret(request)
<< " ip = " << connection->clientIp; << " ip = " << connection->clientIp();
try { try {
auto const range = backend_->fetchLedgerRange(); auto const range = backend_->fetchLedgerRange();
@@ -180,7 +180,7 @@ private:
connection->makeSubscriptionContext(tagFactory_), connection->makeSubscriptionContext(tagFactory_),
tagFactory_.with(connection->tag()), tagFactory_.with(connection->tag()),
*range, *range,
connection->clientIp, connection->clientIp(),
std::cref(apiVersionParser_), std::cref(apiVersionParser_),
connection->isAdmin() connection->isAdmin()
); );
@@ -190,7 +190,7 @@ private:
request, request,
tagFactory_.with(connection->tag()), tagFactory_.with(connection->tag()),
*range, *range,
connection->clientIp, connection->clientIp(),
std::cref(apiVersionParser_), std::cref(apiVersionParser_),
connection->isAdmin() connection->isAdmin()
); );

View File

@@ -23,6 +23,7 @@
#include "util/log/Logger.hpp" #include "util/log/Logger.hpp"
#include "web/AdminVerificationStrategy.hpp" #include "web/AdminVerificationStrategy.hpp"
#include "web/HttpSession.hpp" #include "web/HttpSession.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SslHttpSession.hpp" #include "web/SslHttpSession.hpp"
#include "web/dosguard/DOSGuardInterface.hpp" #include "web/dosguard/DOSGuardInterface.hpp"
#include "web/interface/Concepts.hpp" #include "web/interface/Concepts.hpp"
@@ -87,6 +88,7 @@ class Detector : public std::enable_shared_from_this<Detector<PlainSessionType,
boost::beast::flat_buffer buffer_; boost::beast::flat_buffer buffer_;
std::shared_ptr<AdminVerificationStrategy> const adminVerification_; std::shared_ptr<AdminVerificationStrategy> const adminVerification_;
std::uint32_t maxWsSendingQueueSize_; std::uint32_t maxWsSendingQueueSize_;
std::shared_ptr<ProxyIpResolver> proxyIpResolver_;
public: public:
/** /**
@@ -99,6 +101,7 @@ public:
* @param handler The server handler to use * @param handler The server handler to use
* @param adminVerification The admin verification strategy to use * @param adminVerification The admin verification strategy to use
* @param maxWsSendingQueueSize The maximum size of the sending queue for websocket * @param maxWsSendingQueueSize The maximum size of the sending queue for websocket
* @param proxyIpResolver The client ip resolver if a request was forwarded by a proxy
*/ */
Detector( Detector(
tcp::socket&& socket, tcp::socket&& socket,
@@ -107,7 +110,8 @@ public:
std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard, std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard,
std::shared_ptr<HandlerType> handler, std::shared_ptr<HandlerType> handler,
std::shared_ptr<AdminVerificationStrategy> adminVerification, std::shared_ptr<AdminVerificationStrategy> adminVerification,
std::uint32_t maxWsSendingQueueSize std::uint32_t maxWsSendingQueueSize,
std::shared_ptr<ProxyIpResolver> proxyIpResolver
) )
: stream_(std::move(socket)) : stream_(std::move(socket))
, ctx_(ctx) , ctx_(ctx)
@@ -116,6 +120,7 @@ public:
, handler_(std::move(handler)) , handler_(std::move(handler))
, adminVerification_(std::move(adminVerification)) , adminVerification_(std::move(adminVerification))
, maxWsSendingQueueSize_(maxWsSendingQueueSize) , maxWsSendingQueueSize_(maxWsSendingQueueSize)
, proxyIpResolver_(std::move(proxyIpResolver))
{ {
} }
@@ -169,6 +174,7 @@ public:
stream_.release_socket(), stream_.release_socket(),
ip, ip,
adminVerification_, adminVerification_,
proxyIpResolver_,
*ctx_, *ctx_,
tagFactory_, tagFactory_,
dosGuard_, dosGuard_,
@@ -184,6 +190,7 @@ public:
stream_.release_socket(), stream_.release_socket(),
ip, ip,
adminVerification_, adminVerification_,
proxyIpResolver_,
tagFactory_, tagFactory_,
dosGuard_, dosGuard_,
handler_, handler_,
@@ -219,6 +226,7 @@ class Server : public std::enable_shared_from_this<Server<PlainSessionType, SslS
tcp::acceptor acceptor_; tcp::acceptor acceptor_;
std::shared_ptr<AdminVerificationStrategy> adminVerification_; std::shared_ptr<AdminVerificationStrategy> adminVerification_;
std::uint32_t maxWsSendingQueueSize_; std::uint32_t maxWsSendingQueueSize_;
std::shared_ptr<ProxyIpResolver> proxyIpResolver_;
public: public:
/** /**
@@ -232,6 +240,7 @@ public:
* @param handler The server handler to use * @param handler The server handler to use
* @param adminVerification The admin verification strategy to use * @param adminVerification The admin verification strategy to use
* @param maxWsSendingQueueSize The maximum size of the sending queue for websocket * @param maxWsSendingQueueSize The maximum size of the sending queue for websocket
* @param proxyIpResolver The client ip resolver if a request was forwarded by a proxy
*/ */
Server( Server(
boost::asio::io_context& ioc, boost::asio::io_context& ioc,
@@ -241,7 +250,8 @@ public:
dosguard::DOSGuardInterface& dosGuard, dosguard::DOSGuardInterface& dosGuard,
std::shared_ptr<HandlerType> handler, std::shared_ptr<HandlerType> handler,
std::shared_ptr<AdminVerificationStrategy> adminVerification, std::shared_ptr<AdminVerificationStrategy> adminVerification,
std::uint32_t maxWsSendingQueueSize std::uint32_t maxWsSendingQueueSize,
ProxyIpResolver proxyIpResolver
) )
: ioc_(std::ref(ioc)) : ioc_(std::ref(ioc))
, ctx_(std::move(ctx)) , ctx_(std::move(ctx))
@@ -251,6 +261,7 @@ public:
, acceptor_(boost::asio::make_strand(ioc)) , acceptor_(boost::asio::make_strand(ioc))
, adminVerification_(std::move(adminVerification)) , adminVerification_(std::move(adminVerification))
, maxWsSendingQueueSize_(maxWsSendingQueueSize) , maxWsSendingQueueSize_(maxWsSendingQueueSize)
, proxyIpResolver_(std::make_shared<ProxyIpResolver>(std::move(proxyIpResolver)))
{ {
boost::beast::error_code ec; boost::beast::error_code ec;
@@ -310,7 +321,8 @@ private:
dosGuard_, dosGuard_,
handler_, handler_,
adminVerification_, adminVerification_,
maxWsSendingQueueSize_ maxWsSendingQueueSize_,
proxyIpResolver_
) )
->run(); ->run();
} }
@@ -364,6 +376,8 @@ makeHttpServer(
// each ledger. we allow user delay 3 ledgers by default // each ledger. we allow user delay 3 ledgers by default
auto const maxWsSendingQueueSize = serverConfig.get<uint32_t>("ws_max_sending_queue_size"); auto const maxWsSendingQueueSize = serverConfig.get<uint32_t>("ws_max_sending_queue_size");
auto proxyIpResolver = ProxyIpResolver::fromConfig(config);
auto server = std::make_shared<HttpServer<HandlerType>>( auto server = std::make_shared<HttpServer<HandlerType>>(
ioc, ioc,
std::move(expectedSslContext).value(), std::move(expectedSslContext).value(),
@@ -372,7 +386,8 @@ makeHttpServer(
dosGuard, dosGuard,
handler, handler,
std::move(expectedAdminVerification).value(), std::move(expectedAdminVerification).value(),
maxWsSendingQueueSize maxWsSendingQueueSize,
std::move(proxyIpResolver)
); );
server->run(); server->run();

View File

@@ -21,6 +21,7 @@
#include "util/Taggable.hpp" #include "util/Taggable.hpp"
#include "web/AdminVerificationStrategy.hpp" #include "web/AdminVerificationStrategy.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SslWsSession.hpp" #include "web/SslWsSession.hpp"
#include "web/dosguard/DOSGuardInterface.hpp" #include "web/dosguard/DOSGuardInterface.hpp"
#include "web/impl/HttpBase.hpp" #include "web/impl/HttpBase.hpp"
@@ -70,6 +71,7 @@ public:
* @param socket The socket. Ownership is transferred to HttpSession * @param socket The socket. Ownership is transferred to HttpSession
* @param ip Client's IP address * @param ip Client's IP address
* @param adminVerification The admin verification strategy to use * @param adminVerification The admin verification strategy to use
* @param proxyIpResolver The client ip resolver if a request was forwarded by a proxy
* @param ctx The SSL context * @param ctx The SSL context
* @param tagFactory A factory that is used to generate tags to track requests and sessions * @param tagFactory A factory that is used to generate tags to track requests and sessions
* @param dosGuard The denial of service guard to use * @param dosGuard The denial of service guard to use
@@ -81,6 +83,7 @@ public:
tcp::socket&& socket, tcp::socket&& socket,
std::string const& ip, std::string const& ip,
std::shared_ptr<AdminVerificationStrategy> const& adminVerification, std::shared_ptr<AdminVerificationStrategy> const& adminVerification,
std::shared_ptr<ProxyIpResolver> proxyIpResolver,
boost::asio::ssl::context& ctx, boost::asio::ssl::context& ctx,
std::reference_wrapper<util::TagDecoratorFactory const> tagFactory, std::reference_wrapper<util::TagDecoratorFactory const> tagFactory,
std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard, std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard,
@@ -92,6 +95,7 @@ public:
ip, ip,
tagFactory, tagFactory,
adminVerification, adminVerification,
std::move(proxyIpResolver),
dosGuard, dosGuard,
handler, handler,
std::move(buffer) std::move(buffer)
@@ -173,7 +177,7 @@ public:
{ {
std::make_shared<SslWsUpgrader<HandlerType>>( std::make_shared<SslWsUpgrader<HandlerType>>(
std::move(stream_), std::move(stream_),
this->clientIp, this->clientIp_,
tagFactory_, tagFactory_,
this->dosGuard_, this->dosGuard_,
this->handler_, this->handler_,

View File

@@ -26,6 +26,7 @@
#include "util/log/Logger.hpp" #include "util/log/Logger.hpp"
#include "util/prometheus/Http.hpp" #include "util/prometheus/Http.hpp"
#include "web/AdminVerificationStrategy.hpp" #include "web/AdminVerificationStrategy.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SubscriptionContextInterface.hpp" #include "web/SubscriptionContextInterface.hpp"
#include "web/dosguard/DOSGuardInterface.hpp" #include "web/dosguard/DOSGuardInterface.hpp"
#include "web/interface/Concepts.hpp" #include "web/interface/Concepts.hpp"
@@ -120,6 +121,7 @@ class HttpBase : public ConnectionBase {
std::shared_ptr<void> res_; std::shared_ptr<void> res_;
SendLambda sender_; SendLambda sender_;
std::shared_ptr<AdminVerificationStrategy> adminVerification_; std::shared_ptr<AdminVerificationStrategy> adminVerification_;
std::shared_ptr<ProxyIpResolver> proxyIpResolver_;
protected: protected:
boost::beast::flat_buffer buffer_; boost::beast::flat_buffer buffer_;
@@ -164,6 +166,7 @@ public:
std::string const& ip, std::string const& ip,
std::reference_wrapper<util::TagDecoratorFactory const> tagFactory, std::reference_wrapper<util::TagDecoratorFactory const> tagFactory,
std::shared_ptr<AdminVerificationStrategy> adminVerification, std::shared_ptr<AdminVerificationStrategy> adminVerification,
std::shared_ptr<ProxyIpResolver> proxyIpResolver,
std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard, std::reference_wrapper<dosguard::DOSGuardInterface> dosGuard,
std::shared_ptr<HandlerType> handler, std::shared_ptr<HandlerType> handler,
boost::beast::flat_buffer buffer boost::beast::flat_buffer buffer
@@ -171,6 +174,7 @@ public:
: ConnectionBase(tagFactory, ip) : ConnectionBase(tagFactory, ip)
, sender_(*this) , sender_(*this)
, adminVerification_(std::move(adminVerification)) , adminVerification_(std::move(adminVerification))
, proxyIpResolver_(std::move(proxyIpResolver))
, buffer_(std::move(buffer)) , buffer_(std::move(buffer))
, dosGuard_(dosGuard) , dosGuard_(dosGuard)
, handler_(std::move(handler)) , handler_(std::move(handler))
@@ -183,7 +187,7 @@ public:
{ {
LOG(perfLog_.debug()) << tag() << "http session closed"; LOG(perfLog_.debug()) << tag() << "http session closed";
if (not upgraded) if (not upgraded)
dosGuard_.get().decrement(this->clientIp); dosGuard_.get().decrement(clientIp_);
} }
void void
@@ -218,11 +222,19 @@ public:
if (req_.method() == http::verb::get and req_.target() == "/health") if (req_.method() == http::verb::get and req_.target() == "/health")
return sender_(httpResponse(http::status::ok, "text/html", kHEALTH_CHECK_HTML)); return sender_(httpResponse(http::status::ok, "text/html", kHEALTH_CHECK_HTML));
if (auto resolvedIp = proxyIpResolver_->resolveClientIp(clientIp_, req_); resolvedIp != clientIp_) {
LOG(log_.info()) << tag() << "Detected a forwarded request from proxy. Proxy ip: " << clientIp_
<< ". Resolved client ip: " << resolvedIp;
dosGuard_.get().decrement(clientIp_);
clientIp_ = std::move(resolvedIp);
dosGuard_.get().increment(clientIp_);
}
// Update isAdmin property of the connection // Update isAdmin property of the connection
ConnectionBase::isAdmin_ = adminVerification_->isAdmin(req_, this->clientIp); ConnectionBase::isAdmin_ = adminVerification_->isAdmin(req_, clientIp_);
if (boost::beast::websocket::is_upgrade(req_)) { if (boost::beast::websocket::is_upgrade(req_)) {
if (dosGuard_.get().isOk(this->clientIp)) { if (dosGuard_.get().isOk(clientIp_)) {
// Disable the timeout. The websocket::stream uses its own timeout settings. // Disable the timeout. The websocket::stream uses its own timeout settings.
boost::beast::get_lowest_layer(derived().stream()).expires_never(); boost::beast::get_lowest_layer(derived().stream()).expires_never();
@@ -240,7 +252,7 @@ public:
return sender_(httpResponse(http::status::bad_request, "text/html", "Expected a POST request")); return sender_(httpResponse(http::status::bad_request, "text/html", "Expected a POST request"));
} }
LOG(log_.info()) << tag() << "Received request from ip = " << clientIp; LOG(log_.info()) << tag() << "Received request from ip = " << clientIp_;
try { try {
(*handler_)(req_.body(), derived().shared_from_this()); (*handler_)(req_.body(), derived().shared_from_this());
@@ -271,7 +283,7 @@ public:
void void
send(std::string&& msg, http::status status = http::status::ok) override send(std::string&& msg, http::status status = http::status::ok) override
{ {
if (!dosGuard_.get().add(clientIp, msg.size())) { if (!dosGuard_.get().add(clientIp_, msg.size())) {
auto jsonResponse = boost::json::parse(msg).as_object(); auto jsonResponse = boost::json::parse(msg).as_object();
jsonResponse["warning"] = "load"; jsonResponse["warning"] = "load";
if (jsonResponse.contains("warnings") && jsonResponse["warnings"].is_array()) { if (jsonResponse.contains("warnings") && jsonResponse["warnings"].is_array()) {

View File

@@ -124,7 +124,7 @@ public:
~WsBase() override ~WsBase() override
{ {
LOG(perfLog_.debug()) << tag() << "session closed"; LOG(perfLog_.debug()) << tag() << "session closed";
dosGuard_.get().decrement(clientIp); dosGuard_.get().decrement(clientIp_);
} }
Derived<HandlerType>& Derived<HandlerType>&
@@ -217,7 +217,7 @@ public:
void void
send(std::string&& msg, http::status) override send(std::string&& msg, http::status) override
{ {
if (!dosGuard_.get().add(clientIp, msg.size())) { if (!dosGuard_.get().add(clientIp_, msg.size())) {
auto jsonResponse = boost::json::parse(msg).as_object(); auto jsonResponse = boost::json::parse(msg).as_object();
jsonResponse["warning"] = "load"; jsonResponse["warning"] = "load";
@@ -283,7 +283,7 @@ public:
if (ec) if (ec)
return wsFail(ec, "read"); return wsFail(ec, "read");
LOG(perfLog_.info()) << tag() << "Received request from ip = " << this->clientIp; LOG(perfLog_.info()) << tag() << "Received request from ip = " << clientIp_;
std::string requestStr{static_cast<char const*>(buffer_.data().data()), buffer_.size()}; std::string requestStr{static_cast<char const*>(buffer_.data().data()), buffer_.size()};

View File

@@ -45,9 +45,9 @@ struct ConnectionBase : public util::Taggable {
protected: protected:
boost::system::error_code ec_; boost::system::error_code ec_;
bool isAdmin_ = false; bool isAdmin_ = false;
std::string clientIp_;
public: public:
std::string const clientIp;
bool upgraded = false; bool upgraded = false;
/** /**
@@ -57,7 +57,7 @@ public:
* @param ip The IP address of the connected peer * @param ip The IP address of the connected peer
*/ */
ConnectionBase(util::TagDecoratorFactory const& tagFactory, std::string ip) ConnectionBase(util::TagDecoratorFactory const& tagFactory, std::string ip)
: Taggable(tagFactory), clientIp(std::move(ip)) : Taggable(tagFactory), clientIp_(std::move(ip))
{ {
} }
@@ -119,5 +119,16 @@ public:
{ {
return isAdmin_; return isAdmin_;
} }
/**
* @brief Get the IP address of the client.
*
* @return The IP address of the client.
*/
[[nodiscard]] std::string const&
clientIp() const
{
return clientIp_;
}
}; };
} // namespace web } // namespace web

View File

@@ -34,6 +34,7 @@
#include <memory> #include <memory>
#include <optional> #include <optional>
#include <string> #include <string>
#include <utility>
namespace web::ng { namespace web::ng {
@@ -70,6 +71,17 @@ public:
std::string const& std::string const&
ip() const; ip() const;
/**
* @brief Set the ip of the client.
*
* @param newIp The new ip to set.
*/
void
setIp(std::string newIp)
{
ip_ = std::move(newIp);
}
/** /**
* @brief Get whether the client is an admin. * @brief Get whether the client is an admin.
* *

View File

@@ -25,6 +25,7 @@
#include "util/config/ConfigDefinition.hpp" #include "util/config/ConfigDefinition.hpp"
#include "util/config/ObjectView.hpp" #include "util/config/ObjectView.hpp"
#include "util/log/Logger.hpp" #include "util/log/Logger.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/ng/Connection.hpp" #include "web/ng/Connection.hpp"
#include "web/ng/MessageHandler.hpp" #include "web/ng/MessageHandler.hpp"
#include "web/ng/ProcessingPolicy.hpp" #include "web/ng/ProcessingPolicy.hpp"
@@ -205,16 +206,16 @@ Server::Server(
ProcessingPolicy processingPolicy, ProcessingPolicy processingPolicy,
std::optional<size_t> parallelRequestLimit, std::optional<size_t> parallelRequestLimit,
util::TagDecoratorFactory tagDecoratorFactory, util::TagDecoratorFactory tagDecoratorFactory,
ProxyIpResolver proxyIpResolver,
std::optional<size_t> maxSubscriptionSendQueueSize, std::optional<size_t> maxSubscriptionSendQueueSize,
OnConnectCheck onConnectCheck, Hooks hooks
OnDisconnectHook onDisconnectHook
) )
: ctx_{ctx} : ctx_{ctx}
, sslContext_{std::move(sslContext)} , sslContext_{std::move(sslContext)}
, tagDecoratorFactory_{tagDecoratorFactory} , tagDecoratorFactory_{tagDecoratorFactory}
, connectionHandler_{processingPolicy, parallelRequestLimit, tagDecoratorFactory_, maxSubscriptionSendQueueSize, std::move(onDisconnectHook)} , connectionHandler_{processingPolicy, parallelRequestLimit, tagDecoratorFactory_, maxSubscriptionSendQueueSize, std::move(proxyIpResolver), std::move(hooks.onDisconnectHook), std::move(hooks.onIpChangeHook)}
, endpoint_{std::move(endpoint)} , endpoint_{std::move(endpoint)}
, onConnectCheck_{std::move(onConnectCheck)} , onConnectCheck_{std::move(hooks.onConnectCheck)}
{ {
} }
@@ -337,6 +338,7 @@ std::expected<Server, std::string>
makeServer( makeServer(
util::config::ClioConfigDefinition const& config, util::config::ClioConfigDefinition const& config,
Server::OnConnectCheck onConnectCheck, Server::OnConnectCheck onConnectCheck,
Server::OnIpChangeHook onIpChangeHook,
Server::OnDisconnectHook onDisconnectHook, Server::OnDisconnectHook onDisconnectHook,
boost::asio::io_context& context boost::asio::io_context& context
) )
@@ -365,6 +367,8 @@ makeServer(
auto const maxSubscriptionSendQueueSize = serverConfig.get<size_t>("ws_max_sending_queue_size"); auto const maxSubscriptionSendQueueSize = serverConfig.get<size_t>("ws_max_sending_queue_size");
auto proxyIpResolver = ProxyIpResolver::fromConfig(config);
return std::expected<Server, std::string>{ return std::expected<Server, std::string>{
std::in_place, std::in_place,
context, context,
@@ -373,9 +377,13 @@ makeServer(
processingPolicy, processingPolicy,
parallelRequestLimit, parallelRequestLimit,
util::TagDecoratorFactory(config), util::TagDecoratorFactory(config),
std::move(proxyIpResolver),
maxSubscriptionSendQueueSize, maxSubscriptionSendQueueSize,
std::move(onConnectCheck), Server::Hooks{
std::move(onDisconnectHook) .onConnectCheck = std::move(onConnectCheck),
.onIpChangeHook = std::move(onIpChangeHook),
.onDisconnectHook = std::move(onDisconnectHook)
}
}; };
} }

View File

@@ -22,6 +22,7 @@
#include "util/Taggable.hpp" #include "util/Taggable.hpp"
#include "util/config/ConfigDefinition.hpp" #include "util/config/ConfigDefinition.hpp"
#include "util/log/Logger.hpp" #include "util/log/Logger.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/ng/Connection.hpp" #include "web/ng/Connection.hpp"
#include "web/ng/MessageHandler.hpp" #include "web/ng/MessageHandler.hpp"
#include "web/ng/ProcessingPolicy.hpp" #include "web/ng/ProcessingPolicy.hpp"
@@ -62,11 +63,22 @@ public:
*/ */
using OnConnectCheck = std::function<std::expected<void, Response>(Connection const&)>; using OnConnectCheck = std::function<std::expected<void, Response>(Connection const&)>;
using OnIpChangeHook = impl::ConnectionHandler::OnIpChangeHook;
/** /**
* @brief Hook called when any connection disconnects * @brief Hook called when any connection disconnects
*/ */
using OnDisconnectHook = impl::ConnectionHandler::OnDisconnectHook; using OnDisconnectHook = impl::ConnectionHandler::OnDisconnectHook;
/**
* @brief A struct that holds all the hooks for the server.
*/
struct Hooks {
OnConnectCheck onConnectCheck;
OnIpChangeHook onIpChangeHook;
OnDisconnectHook onDisconnectHook;
};
private: private:
util::Logger log_{"WebServer"}; util::Logger log_{"WebServer"};
util::Logger perfLog_{"Performance"}; util::Logger perfLog_{"Performance"};
@@ -94,9 +106,9 @@ public:
* @param parallelRequestLimit The limit of requests for one connection that can be processed in parallel. Only used * @param parallelRequestLimit The limit of requests for one connection that can be processed in parallel. Only used
* if processingPolicy is parallel. * if processingPolicy is parallel.
* @param tagDecoratorFactory The tag decorator factory. * @param tagDecoratorFactory The tag decorator factory.
* @param proxyIpResolver The client ip resolver if a request was forwarded by a proxy
* @param maxSubscriptionSendQueueSize The maximum size of the subscription send queue. * @param maxSubscriptionSendQueueSize The maximum size of the subscription send queue.
* @param onConnectCheck The check to perform on each connection. * @param hooks The server hooks
* @param onDisconnectHook The hook to call on each disconnection.
*/ */
Server( Server(
boost::asio::io_context& ctx, boost::asio::io_context& ctx,
@@ -105,9 +117,9 @@ public:
ProcessingPolicy processingPolicy, ProcessingPolicy processingPolicy,
std::optional<size_t> parallelRequestLimit, std::optional<size_t> parallelRequestLimit,
util::TagDecoratorFactory tagDecoratorFactory, util::TagDecoratorFactory tagDecoratorFactory,
ProxyIpResolver proxyIpResolver,
std::optional<size_t> maxSubscriptionSendQueueSize, std::optional<size_t> maxSubscriptionSendQueueSize,
OnConnectCheck onConnectCheck, Hooks hooks
OnDisconnectHook onDisconnectHook
); );
/** /**
@@ -185,6 +197,7 @@ std::expected<Server, std::string>
makeServer( makeServer(
util::config::ClioConfigDefinition const& config, util::config::ClioConfigDefinition const& config,
Server::OnConnectCheck onConnectCheck, Server::OnConnectCheck onConnectCheck,
Server::OnIpChangeHook onIpChangeHook,
Server::OnDisconnectHook onDisconnectHook, Server::OnDisconnectHook onDisconnectHook,
boost::asio::io_context& context boost::asio::io_context& context
); );

View File

@@ -23,6 +23,7 @@
#include "util/CoroutineGroup.hpp" #include "util/CoroutineGroup.hpp"
#include "util/Taggable.hpp" #include "util/Taggable.hpp"
#include "util/log/Logger.hpp" #include "util/log/Logger.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SubscriptionContextInterface.hpp" #include "web/SubscriptionContextInterface.hpp"
#include "web/ng/Connection.hpp" #include "web/ng/Connection.hpp"
#include "web/ng/Error.hpp" #include "web/ng/Error.hpp"
@@ -90,13 +91,17 @@ ConnectionHandler::ConnectionHandler(
std::optional<size_t> maxParallelRequests, std::optional<size_t> maxParallelRequests,
util::TagDecoratorFactory& tagFactory, util::TagDecoratorFactory& tagFactory,
std::optional<size_t> maxSubscriptionSendQueueSize, std::optional<size_t> maxSubscriptionSendQueueSize,
OnDisconnectHook onDisconnectHook ProxyIpResolver proxyIpResolver,
OnDisconnectHook onDisconnectHook,
OnIpChangeHook onIpChangeHook
) )
: processingPolicy_{processingPolicy} : processingPolicy_{processingPolicy}
, maxParallelRequests_{maxParallelRequests} , maxParallelRequests_{maxParallelRequests}
, tagFactory_{tagFactory} , tagFactory_{tagFactory}
, maxSubscriptionSendQueueSize_{maxSubscriptionSendQueueSize} , maxSubscriptionSendQueueSize_{maxSubscriptionSendQueueSize}
, proxyIpResolver_(std::move(proxyIpResolver))
, onDisconnectHook_{std::move(onDisconnectHook)} , onDisconnectHook_{std::move(onDisconnectHook)}
, onIpChangeHook_(std::move(onIpChangeHook))
{ {
} }
@@ -277,9 +282,9 @@ ConnectionHandler::sequentRequestResponseLoop(
if (not expectedRequest) if (not expectedRequest)
return handleError(expectedRequest.error(), connection); return handleError(expectedRequest.error(), connection);
LOG(log_.info()) << connection.tag() << "Received request from ip = " << connection.ip(); resolveClientIp(connection, *expectedRequest);
auto maybeReturnValue = processRequest(connection, subscriptionContext, expectedRequest.value(), yield); auto maybeReturnValue = processRequest(connection, subscriptionContext, *expectedRequest, yield);
if (maybeReturnValue.has_value()) if (maybeReturnValue.has_value())
return maybeReturnValue.value(); return maybeReturnValue.value();
} }
@@ -308,6 +313,8 @@ ConnectionHandler::parallelRequestResponseLoop(
break; break;
} }
resolveClientIp(connection, *expectedRequest);
if (not tasksGroup.isFull()) { if (not tasksGroup.isFull()) {
bool const spawnSuccess = tasksGroup.spawn( bool const spawnSuccess = tasksGroup.spawn(
yield, // spawn on the same strand yield, // spawn on the same strand
@@ -385,4 +392,16 @@ ConnectionHandler::handleRequest(
} }
} }
void
ConnectionHandler::resolveClientIp(Connection& connection, Request const& request) const
{
if (auto resolvedClientIp = proxyIpResolver_.resolveClientIp(connection.ip(), request.httpHeaders());
resolvedClientIp != connection.ip()) {
LOG(log_.info()) << connection.tag() << "Detected a forwarded request from proxy. Proxy ip: " << connection.ip()
<< ". Resolved client ip: " << resolvedClientIp;
onIpChangeHook_(connection.ip(), resolvedClientIp);
connection.setIp(std::move(resolvedClientIp));
}
}
} // namespace web::ng::impl } // namespace web::ng::impl

View File

@@ -26,6 +26,7 @@
#include "util/prometheus/Gauge.hpp" #include "util/prometheus/Gauge.hpp"
#include "util/prometheus/Label.hpp" #include "util/prometheus/Label.hpp"
#include "util/prometheus/Prometheus.hpp" #include "util/prometheus/Prometheus.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SubscriptionContextInterface.hpp" #include "web/SubscriptionContextInterface.hpp"
#include "web/ng/Connection.hpp" #include "web/ng/Connection.hpp"
#include "web/ng/Error.hpp" #include "web/ng/Error.hpp"
@@ -52,6 +53,7 @@ namespace web::ng::impl {
class ConnectionHandler { class ConnectionHandler {
public: public:
using OnDisconnectHook = std::function<void(Connection const&)>; using OnDisconnectHook = std::function<void(Connection const&)>;
using OnIpChangeHook = std::function<void(std::string const&, std::string const&)>;
using TargetToHandlerMap = std::unordered_map<std::string, MessageHandler, util::StringHash, std::equal_to<>>; using TargetToHandlerMap = std::unordered_map<std::string, MessageHandler, util::StringHash, std::equal_to<>>;
private: private:
@@ -64,7 +66,10 @@ private:
std::reference_wrapper<util::TagDecoratorFactory> tagFactory_; std::reference_wrapper<util::TagDecoratorFactory> tagFactory_;
std::optional<size_t> maxSubscriptionSendQueueSize_; std::optional<size_t> maxSubscriptionSendQueueSize_;
ProxyIpResolver proxyIpResolver_;
OnDisconnectHook onDisconnectHook_; OnDisconnectHook onDisconnectHook_;
OnIpChangeHook onIpChangeHook_;
TargetToHandlerMap getHandlers_; TargetToHandlerMap getHandlers_;
TargetToHandlerMap postHandlers_; TargetToHandlerMap postHandlers_;
@@ -84,7 +89,9 @@ public:
std::optional<size_t> maxParallelRequests, std::optional<size_t> maxParallelRequests,
util::TagDecoratorFactory& tagFactory, util::TagDecoratorFactory& tagFactory,
std::optional<size_t> maxSubscriptionSendQueueSize, std::optional<size_t> maxSubscriptionSendQueueSize,
OnDisconnectHook onDisconnectHook ProxyIpResolver proxyIpResolver,
OnDisconnectHook onDisconnectHook,
OnIpChangeHook onIpChangeHook
); );
ConnectionHandler(ConnectionHandler&&) = delete; ConnectionHandler(ConnectionHandler&&) = delete;
@@ -159,6 +166,9 @@ private:
Request const& request, Request const& request,
boost::asio::yield_context yield boost::asio::yield_context yield
); );
void
resolveClientIp(Connection& connection, Request const& request) const;
}; };
} // namespace web::ng::impl } // namespace web::ng::impl

View File

@@ -50,6 +50,7 @@
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <utility> #include <utility>
#include <variant>
#include <vector> #include <vector>
namespace http = boost::beast::http; namespace http = boost::beast::http;
@@ -81,8 +82,8 @@ syncRequest(
req.set(http::field::host, host); req.set(http::field::host, host);
req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING); req.set(http::field::user_agent, BOOST_BEAST_VERSION_STRING);
for (auto& header : additionalHeaders) { for (auto const& header : additionalHeaders) {
req.set(header.name, header.value); std::visit([&header, &req](auto const& name) { req.set(name, header.value); }, header.name);
} }
req.target(target); req.target(target);
@@ -106,6 +107,10 @@ WebHeader::WebHeader(http::field name, std::string value) : name(name), value(st
{ {
} }
WebHeader::WebHeader(std::string_view name, std::string value) : name(std::string{name}), value(std::move(value))
{
}
std::pair<boost::beast::http::status, std::string> std::pair<boost::beast::http::status, std::string>
HttpSyncClient::post( HttpSyncClient::post(
std::string const& host, std::string const& host,

View File

@@ -35,12 +35,14 @@
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <utility> #include <utility>
#include <variant>
#include <vector> #include <vector>
struct WebHeader { struct WebHeader {
WebHeader(boost::beast::http::field name, std::string value); WebHeader(boost::beast::http::field name, std::string value);
WebHeader(std::string_view name, std::string value);
boost::beast::http::field name; std::variant<boost::beast::http::field, std::string> name;
std::string value; std::string value;
}; };

View File

@@ -46,6 +46,7 @@
#include <string> #include <string>
#include <string_view> #include <string_view>
#include <utility> #include <utility>
#include <variant>
#include <vector> #include <vector>
namespace http = boost::beast::http; namespace http = boost::beast::http;
@@ -69,7 +70,7 @@ WebSocketSyncClient::connect(std::string const& host, std::string const& port, s
[additionalHeaders = std::move(additionalHeaders)](boost::beast::websocket::request_type& req) { [additionalHeaders = std::move(additionalHeaders)](boost::beast::websocket::request_type& req) {
req.set(http::field::user_agent, std::string(BOOST_BEAST_VERSION_STRING) + " websocket-client-coro"); req.set(http::field::user_agent, std::string(BOOST_BEAST_VERSION_STRING) + " websocket-client-coro");
for (auto const& header : additionalHeaders) { for (auto const& header : additionalHeaders) {
req.set(header.name, header.value); std::visit([&header, &req](auto const& name) { req.set(name, header.value); }, header.name);
} }
} }
) )
@@ -164,7 +165,7 @@ WebSocketAsyncClient::connect(
boost::beast::websocket::stream_base::decorator( boost::beast::websocket::stream_base::decorator(
[additionalHeaders = std::move(additionalHeaders)](boost::beast::websocket::request_type& req) { [additionalHeaders = std::move(additionalHeaders)](boost::beast::websocket::request_type& req) {
for (auto const& header : additionalHeaders) { for (auto const& header : additionalHeaders) {
req.set(header.name, header.value); std::visit([&header, &req](auto const& name) { req.set(name, header.value); }, header.name);
} }
} }
) )

View File

@@ -179,6 +179,7 @@ target_sources(
util/RetryTests.cpp util/RetryTests.cpp
util/ScopeGuardTests.cpp util/ScopeGuardTests.cpp
util/SignalsHandlerTests.cpp util/SignalsHandlerTests.cpp
util/ShasumTests.cpp
util/SpawnTests.cpp util/SpawnTests.cpp
util/StopHelperTests.cpp util/StopHelperTests.cpp
util/TimeUtilsTests.cpp util/TimeUtilsTests.cpp
@@ -201,6 +202,7 @@ target_sources(
web/ng/impl/HttpConnectionTests.cpp web/ng/impl/HttpConnectionTests.cpp
web/ng/impl/ServerSslContextTests.cpp web/ng/impl/ServerSslContextTests.cpp
web/ng/impl/WsConnectionTests.cpp web/ng/impl/WsConnectionTests.cpp
web/ProxyIpResolverTests.cpp
web/RPCServerHandlerTests.cpp web/RPCServerHandlerTests.cpp
web/ServerTests.cpp web/ServerTests.cpp
web/SubscriptionContextTests.cpp web/SubscriptionContextTests.cpp

View File

@@ -92,6 +92,21 @@ TEST_F(OnConnectCheckTests, RateLimited)
EXPECT_EQ(httpResponse.body(), "Too many requests"); EXPECT_EQ(httpResponse.body(), "Too many requests");
} }
struct IpChangeHookTests : WebHandlersTest {
IpChangeHook ipChangeHook{dosGuardMock};
};
TEST_F(IpChangeHookTests, CallsDecrementAndIncrement)
{
std::string const oldIp = "old ip";
std::string const newIp = "new ip";
EXPECT_CALL(dosGuardMock, decrement(oldIp));
EXPECT_CALL(dosGuardMock, increment(newIp));
ipChangeHook(oldIp, newIp);
}
struct DisconnectHookTests : WebHandlersTest { struct DisconnectHookTests : WebHandlersTest {
DisconnectHook disconnectHook{dosGuardMock}; DisconnectHook disconnectHook{dosGuardMock};
}; };

View File

@@ -0,0 +1,47 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "util/Shasum.hpp"
#include <gtest/gtest.h>
#include <xrpl/basics/base_uint.h>
using namespace util;
struct ShasumTest : testing::Test {
static constexpr auto kEMPTY_HASH = "E3B0C44298FC1C149AFBF4C8996FB92427AE41E4649B934CA495991B7852B855";
static constexpr auto kHELLO_WORLD_HASH = "B94D27B9934D3E08A52E52D7DA7DABFAC484EFE37A5380EE9088F7ACE2EFCDE9";
};
TEST_F(ShasumTest, sha256sum)
{
ripple::uint256 expected;
ASSERT_TRUE(expected.parseHex(kEMPTY_HASH));
EXPECT_EQ(sha256sum(""), expected);
ASSERT_TRUE(expected.parseHex(kHELLO_WORLD_HASH));
EXPECT_EQ(sha256sum("hello world"), expected);
}
TEST_F(ShasumTest, sha256sumString)
{
EXPECT_EQ(sha256sumString(""), kEMPTY_HASH);
EXPECT_EQ(sha256sumString("hello world"), kHELLO_WORLD_HASH);
}

View File

@@ -0,0 +1,214 @@
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2025, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "util/NameGenerator.hpp"
#include "util/config/Array.hpp"
#include "util/config/ConfigDefinition.hpp"
#include "util/config/ConfigFileJson.hpp"
#include "util/config/ConfigValue.hpp"
#include "util/config/Types.hpp"
#include "web/ProxyIpResolver.hpp"
#include <boost/beast/http/field.hpp>
#include <boost/json/parse.hpp>
#include <fmt/format.h>
#include <gtest/gtest.h>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
namespace http = boost::beast::http;
using namespace web;
struct ProxyIpResolverTestParams {
std::string testName;
std::unordered_set<std::string> proxyIps;
std::unordered_set<std::string> proxyTokens;
std::vector<std::pair<std::string, std::string>> headers;
std::string connectionIp;
std::string expectedIp;
};
class ProxyIpResolverTest : public ::testing::TestWithParam<ProxyIpResolverTestParams> {};
TEST_F(ProxyIpResolverTest, FromConfig)
{
using namespace util::config;
ClioConfigDefinition config{{
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
}};
auto const proxyIp = "1.2.3.4";
auto const clientIp = "5.6.7.8";
auto const proxyToken = "some_proxy_token";
auto const configStr = fmt::format(
R"({{
"server": {{
"proxy": {{
"ips": ["{}"],
"tokens": ["{}"]
}}
}}
}})",
proxyIp,
proxyToken
);
auto const err = config.parse(ConfigFileJson{boost::json::parse(configStr).as_object()});
ASSERT_FALSE(err.has_value());
auto const proxyIpResolver = ProxyIpResolver::fromConfig(config);
ProxyIpResolver::HttpHeaders headers;
EXPECT_EQ(proxyIpResolver.resolveClientIp(clientIp, headers), clientIp);
EXPECT_EQ(proxyIpResolver.resolveClientIp(proxyIp, headers), proxyIp);
headers.set(boost::beast::http::field::forwarded, fmt::format("for={}", clientIp));
EXPECT_EQ(proxyIpResolver.resolveClientIp(clientIp, headers), clientIp);
EXPECT_EQ(proxyIpResolver.resolveClientIp(proxyIp, headers), clientIp);
headers.set(ProxyIpResolver::kPROXY_TOKEN_HEADER, proxyToken);
EXPECT_EQ(proxyIpResolver.resolveClientIp(clientIp, headers), clientIp);
EXPECT_EQ(proxyIpResolver.resolveClientIp(proxyIp, headers), clientIp);
EXPECT_EQ(proxyIpResolver.resolveClientIp("127.0.0.1", headers), clientIp);
}
TEST_P(ProxyIpResolverTest, ResolveClientIp)
{
auto const& params = GetParam();
ProxyIpResolver resolver(params.proxyIps, params.proxyTokens);
ProxyIpResolver::HttpHeaders headers;
for (auto const& [key, value] : params.headers) {
headers.set(key, value);
}
EXPECT_EQ(resolver.resolveClientIp(params.connectionIp, headers), params.expectedIp);
}
INSTANTIATE_TEST_SUITE_P(
ProxyIpResolverTests,
ProxyIpResolverTest,
::testing::Values(
ProxyIpResolverTestParams{
.testName = "NoProxy",
.proxyIps = {},
.proxyTokens = {},
.headers = {},
.connectionIp = "1.2.3.4",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "TrustedProxyIpWithForwardedHeader",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "for=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "TrustedProxyIpWithoutForwardedHeader",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
},
ProxyIpResolverTestParams{
.testName = "UntrustedProxyIpWithForwardedHeader",
.proxyIps = {},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "for=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
},
ProxyIpResolverTestParams{
.testName = "TrustedProxyTokenWithForwardedHeader",
.proxyIps = {},
.proxyTokens = {"test_token"},
.headers =
{{std::string(ProxyIpResolver::kPROXY_TOKEN_HEADER), "test_token"},
{std::string(http::to_string(http::field::forwarded)), "for=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "TrustedProxyTokenWithoutForwardedHeader",
.proxyIps = {},
.proxyTokens = {"test_token"},
.headers = {{std::string(ProxyIpResolver::kPROXY_TOKEN_HEADER), "test_token"}},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
},
ProxyIpResolverTestParams{
.testName = "UntrustedProxyTokenWithForwardedHeader",
.proxyIps = {},
.proxyTokens = {},
.headers =
{{std::string(ProxyIpResolver::kPROXY_TOKEN_HEADER), "test_token"},
{std::string(http::to_string(http::field::forwarded)), "for=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
},
ProxyIpResolverTestParams{
.testName = "ForwardedHeaderWithAdditionalFields",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers =
{{std::string(http::to_string(http::field::forwarded)),
"by=203.0.113.43; for=1.2.3.4; host=example.com; proto=https"}},
.connectionIp = "5.6.7.8",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "ForwardedHeaderWithDifferentCase",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "For=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "ForwardedHeaderWithoutFor",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "by=1.2.3.4"}},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
},
ProxyIpResolverTestParams{
.testName = "ForwardedHeaderWithIpInQuotes",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "for=\"1.2.3.4\""}},
.connectionIp = "5.6.7.8",
.expectedIp = "1.2.3.4"
},
ProxyIpResolverTestParams{
.testName = "ForwardedHeaderIsIncorrect",
.proxyIps = {"5.6.7.8"},
.proxyTokens = {},
.headers = {{std::string(http::to_string(http::field::forwarded)), "for=\";some_other_text"}},
.connectionIp = "5.6.7.8",
.expectedIp = "5.6.7.8"
}
),
tests::util::kNAME_GENERATOR
);

View File

@@ -132,8 +132,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPDefaultPath)
] ]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -153,7 +153,7 @@ TEST_F(WebRPCServerHandlerTest, HTTPRejectedByDosguard)
"params": [{}] "params": [{}]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(false)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(false));
(*handler)(kREQUEST, session); (*handler)(kREQUEST, session);
EXPECT_EQ(session->slowDownCallsCounter, 1); EXPECT_EQ(session->slowDownCallsCounter, 1);
@@ -166,8 +166,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPRejectedByDosguardAfterParsing)
"params": [{}] "params": [{}]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(false)); EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(false));
(*handler)(kREQUEST, session); (*handler)(kREQUEST, session);
EXPECT_EQ(session->slowDownCallsCounter, 1); EXPECT_EQ(session->slowDownCallsCounter, 1);
@@ -198,8 +198,8 @@ TEST_F(WebRPCServerHandlerTest, WsNormalPath)
} }
] ]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -221,7 +221,7 @@ TEST_F(WebRPCServerHandlerTest, WsRejectedByDosguard)
"api_version": 2 "api_version": 2
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(false)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(false));
(*handler)(kREQUEST, session); (*handler)(kREQUEST, session);
EXPECT_EQ(session->slowDownCallsCounter, 1); EXPECT_EQ(session->slowDownCallsCounter, 1);
@@ -236,8 +236,8 @@ TEST_F(WebRPCServerHandlerTest, WsRejectedByDosguardAfterParsing)
"api_version": 2 "api_version": 2
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(false)); .WillOnce(testing::Return(false));
(*handler)(kREQUEST, session); (*handler)(kREQUEST, session);
@@ -274,8 +274,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPForwardedPath)
] ]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -323,8 +323,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPForwardedErrorPath)
] ]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -370,8 +370,8 @@ TEST_F(WebRPCServerHandlerTest, WsForwardedPath)
] ]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -420,8 +420,8 @@ TEST_F(WebRPCServerHandlerTest, WsForwardedErrorPath)
] ]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -472,8 +472,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPErrorPath)
] ]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST_JSON).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -521,8 +521,8 @@ TEST_F(WebRPCServerHandlerTest, WsErrorPath)
"api_version": 2 "api_version": 2
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST_JSON).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -557,8 +557,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPNotReady)
} }
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyNotReady).Times(1); EXPECT_CALL(*rpcEngine, notifyNotReady).Times(1);
@@ -589,8 +589,8 @@ TEST_F(WebRPCServerHandlerTest, WsNotReady)
} }
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyNotReady).Times(1); EXPECT_CALL(*rpcEngine, notifyNotReady).Times(1);
@@ -619,8 +619,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPBadSyntaxWhenRequestSubscribe)
} }
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -636,8 +636,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPMissingCommand)
static constexpr auto kRESPONSE = "Null method"; static constexpr auto kRESPONSE = "Null method";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -654,8 +654,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPCommandNotString)
static constexpr auto kRESPONSE = "method is not string"; static constexpr auto kRESPONSE = "method is not string";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -672,8 +672,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPCommandIsEmpty)
static constexpr auto kRESPONSE = "method is empty"; static constexpr auto kRESPONSE = "method is empty";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -705,8 +705,8 @@ TEST_F(WebRPCServerHandlerTest, WsMissingCommand)
} }
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -726,8 +726,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPParamsUnparsableNotArray)
"params": "wrong" "params": "wrong"
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -747,8 +747,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPParamsUnparsableArrayWithDigit)
"params": [1] "params": [1]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, testing::_)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, request(session->clientIp(), testing::_)).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -780,8 +780,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPInternalError)
"params": [{}] "params": [{}]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST_JSON).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyInternalError).Times(1); EXPECT_CALL(*rpcEngine, notifyInternalError).Times(1);
@@ -815,8 +815,8 @@ TEST_F(WebRPCServerHandlerTest, WsInternalError)
"id": "123" "id": "123"
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST_JSON).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST_JSON).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyInternalError).Times(1); EXPECT_CALL(*rpcEngine, notifyInternalError).Times(1);
@@ -852,8 +852,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPOutDated)
] ]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -895,8 +895,8 @@ TEST_F(WebRPCServerHandlerTest, WsOutdated)
] ]
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, buildResponse(testing::_)) EXPECT_CALL(*rpcEngine, buildResponse(testing::_))
@@ -931,8 +931,8 @@ TEST_F(WebRPCServerHandlerTest, WsTooBusy)
"type": "response" "type": "response"
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*localRpcEngine, notifyTooBusy).Times(1); EXPECT_CALL(*localRpcEngine, notifyTooBusy).Times(1);
@@ -962,8 +962,8 @@ TEST_F(WebRPCServerHandlerTest, HTTPTooBusy)
"type": "response" "type": "response"
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(kREQUEST).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(kREQUEST).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*localRpcEngine, notifyTooBusy).Times(1); EXPECT_CALL(*localRpcEngine, notifyTooBusy).Times(1);
@@ -978,7 +978,7 @@ TEST_F(WebRPCServerHandlerTest, HTTPRequestNotJson)
static constexpr auto kREQUEST = "not json"; static constexpr auto kREQUEST = "not json";
static constexpr auto kRESPONSE_PREFIX = "Unable to parse JSON from the request"; static constexpr auto kRESPONSE_PREFIX = "Unable to parse JSON from the request";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -1000,7 +1000,7 @@ TEST_F(WebRPCServerHandlerTest, WsRequestNotJson)
"type": "response" "type": "response"
})JSON"; })JSON";
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -1058,8 +1058,8 @@ TEST_P(WebRPCServerHandlerInvalidAPIVersionParamTest, HTTPInvalidAPIVersion)
backend_->setRange(kMIN_SEQ, kMAX_SEQ); backend_->setRange(kMIN_SEQ, kMAX_SEQ);
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(request).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(request).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);
@@ -1082,8 +1082,8 @@ TEST_P(WebRPCServerHandlerInvalidAPIVersionParamTest, WSInvalidAPIVersion)
backend_->setRange(kMIN_SEQ, kMAX_SEQ); backend_->setRange(kMIN_SEQ, kMAX_SEQ);
EXPECT_CALL(dosguard, isOk(session->clientIp)).WillOnce(testing::Return(true)); EXPECT_CALL(dosguard, isOk(session->clientIp())).WillOnce(testing::Return(true));
EXPECT_CALL(dosguard, request(session->clientIp, boost::json::parse(request).as_object())) EXPECT_CALL(dosguard, request(session->clientIp(), boost::json::parse(request).as_object()))
.WillOnce(testing::Return(true)); .WillOnce(testing::Return(true));
EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1); EXPECT_CALL(*rpcEngine, notifyBadSyntax).Times(1);

View File

@@ -124,6 +124,8 @@ getParseServerConfig(boost::json::value val)
{"server.port", ConfigValue{ConfigType::Integer}}, {"server.port", ConfigValue{ConfigType::Integer}},
{"server.admin_password", ConfigValue{ConfigType::String}.optional()}, {"server.admin_password", ConfigValue{ConfigType::String}.optional()},
{"server.local_admin", ConfigValue{ConfigType::Boolean}.optional()}, {"server.local_admin", ConfigValue{ConfigType::Boolean}.optional()},
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
{"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)}, {"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)},
{"log.tag_style", ConfigValue{ConfigType::String}.defaultValue("uint")}, {"log.tag_style", ConfigValue{ConfigType::String}.defaultValue("uint")},
{"dos_guard.max_fetches", ConfigValue{ConfigType::Integer}}, {"dos_guard.max_fetches", ConfigValue{ConfigType::Integer}},
@@ -519,6 +521,8 @@ getParseAdminServerConfig(boost::json::value val)
{"server.admin_password", ConfigValue{ConfigType::String}.optional()}, {"server.admin_password", ConfigValue{ConfigType::String}.optional()},
{"server.local_admin", ConfigValue{ConfigType::Boolean}.optional()}, {"server.local_admin", ConfigValue{ConfigType::Boolean}.optional()},
{"server.processing_policy", ConfigValue{ConfigType::String}.defaultValue("parallel")}, {"server.processing_policy", ConfigValue{ConfigType::String}.defaultValue("parallel")},
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
{"server.parallel_requests_limit", ConfigValue{ConfigType::Integer}.optional()}, {"server.parallel_requests_limit", ConfigValue{ConfigType::Integer}.optional()},
{"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)}, {"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)},
{"ssl_cert_file", ConfigValue{ConfigType::String}.optional()}, {"ssl_cert_file", ConfigValue{ConfigType::String}.optional()},

View File

@@ -26,11 +26,13 @@
#include "util/Taggable.hpp" #include "util/Taggable.hpp"
#include "util/TestHttpClient.hpp" #include "util/TestHttpClient.hpp"
#include "util/TestWebSocketClient.hpp" #include "util/TestWebSocketClient.hpp"
#include "util/config/Array.hpp"
#include "util/config/ConfigConstraints.hpp" #include "util/config/ConfigConstraints.hpp"
#include "util/config/ConfigDefinition.hpp" #include "util/config/ConfigDefinition.hpp"
#include "util/config/ConfigFileJson.hpp" #include "util/config/ConfigFileJson.hpp"
#include "util/config/ConfigValue.hpp" #include "util/config/ConfigValue.hpp"
#include "util/config/Types.hpp" #include "util/config/Types.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SubscriptionContextInterface.hpp" #include "web/SubscriptionContextInterface.hpp"
#include "web/ng/Connection.hpp" #include "web/ng/Connection.hpp"
#include "web/ng/ProcessingPolicy.hpp" #include "web/ng/ProcessingPolicy.hpp"
@@ -58,6 +60,7 @@
#include <optional> #include <optional>
#include <ranges> #include <ranges>
#include <string> #include <string>
#include <unordered_set>
using namespace web::ng; using namespace web::ng;
using namespace util::config; using namespace util::config;
@@ -85,6 +88,8 @@ TEST_P(MakeServerTest, Make)
{"server.ip", ConfigValue{ConfigType::String}.optional()}, {"server.ip", ConfigValue{ConfigType::String}.optional()},
{"server.port", ConfigValue{ConfigType::Integer}.optional()}, {"server.port", ConfigValue{ConfigType::Integer}.optional()},
{"server.processing_policy", ConfigValue{ConfigType::String}.defaultValue("parallel")}, {"server.processing_policy", ConfigValue{ConfigType::String}.defaultValue("parallel")},
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
{"server.parallel_requests_limit", ConfigValue{ConfigType::Integer}.optional()}, {"server.parallel_requests_limit", ConfigValue{ConfigType::Integer}.optional()},
{"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)}, {"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)},
{"log.tag_style", ConfigValue{ConfigType::String}.defaultValue("uint")}, {"log.tag_style", ConfigValue{ConfigType::String}.defaultValue("uint")},
@@ -95,8 +100,13 @@ TEST_P(MakeServerTest, Make)
auto const errors = config.parse(json); auto const errors = config.parse(json);
ASSERT_TRUE(!errors.has_value()); ASSERT_TRUE(!errors.has_value());
auto const expectedServer = auto const expectedServer = makeServer(
makeServer(config, [](auto&&) -> std::expected<void, Response> { return {}; }, [](auto&&) {}, ioContext_); config,
[](auto&&) -> std::expected<void, Response> { return {}; },
[](auto&&, auto&&) {},
[](auto&&) {},
ioContext_
);
EXPECT_EQ(expectedServer.has_value(), GetParam().expectSuccess); EXPECT_EQ(expectedServer.has_value(), GetParam().expectSuccess);
} }
@@ -173,6 +183,8 @@ protected:
{"server.admin_password", ConfigValue{ConfigType::String}.optional()}, {"server.admin_password", ConfigValue{ConfigType::String}.optional()},
{"server.local_admin", ConfigValue{ConfigType::Boolean}.optional()}, {"server.local_admin", ConfigValue{ConfigType::Boolean}.optional()},
{"server.parallel_requests_limit", ConfigValue{ConfigType::Integer}.optional()}, {"server.parallel_requests_limit", ConfigValue{ConfigType::Integer}.optional()},
{"server.proxy.ips.[]", Array{ConfigValue{ConfigType::String}}},
{"server.proxy.tokens.[]", Array{ConfigValue{ConfigType::String}}},
{"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)}, {"server.ws_max_sending_queue_size", ConfigValue{ConfigType::Integer}.defaultValue(1500)},
{"log.tag_style", ConfigValue{ConfigType::String}.defaultValue("uint")}, {"log.tag_style", ConfigValue{ConfigType::String}.defaultValue("uint")},
{"ssl_key_file", ConfigValue{ConfigType::String}.optional()}, {"ssl_key_file", ConfigValue{ConfigType::String}.optional()},
@@ -180,7 +192,8 @@ protected:
}; };
Server::OnConnectCheck emptyOnConnectCheck_ = [](auto&&) -> std::expected<void, Response> { return {}; }; Server::OnConnectCheck emptyOnConnectCheck_ = [](auto&&) -> std::expected<void, Response> { return {}; };
std::expected<Server, std::string> server_ = makeServer(config_, emptyOnConnectCheck_, [](auto&&) {}, ctx_); std::expected<Server, std::string> server_ =
makeServer(config_, emptyOnConnectCheck_, [](auto&&, auto&&) {}, [](auto&&) {}, ctx_);
std::string requestMessage_ = "some request"; std::string requestMessage_ = "some request";
std::string const headerName_ = "Some-header"; std::string const headerName_ = "Some-header";
@@ -210,9 +223,13 @@ TEST_F(ServerTest, BadEndpoint)
ProcessingPolicy::Sequential, ProcessingPolicy::Sequential,
std::nullopt, std::nullopt,
tagDecoratorFactory, tagDecoratorFactory,
web::ProxyIpResolver{{}, {}},
std::nullopt, std::nullopt,
emptyOnConnectCheck_, Server::Hooks{
[](auto&&) {} .onConnectCheck = emptyOnConnectCheck_,
.onIpChangeHook = [](auto&&, auto&&) {},
.onDisconnectHook = [](auto&&) {}
}
}; };
auto maybeError = server.run(); auto maybeError = server.run();
@@ -274,9 +291,13 @@ TEST_F(ServerHttpTest, OnConnectCheck)
ProcessingPolicy::Sequential, ProcessingPolicy::Sequential,
std::nullopt, std::nullopt,
tagDecoratorFactory, tagDecoratorFactory,
web::ProxyIpResolver{{}, {}},
std::nullopt, std::nullopt,
onConnectCheck.AsStdFunction(), Server::Hooks{
[](auto&&) {} .onConnectCheck = onConnectCheck.AsStdFunction(),
.onIpChangeHook = [](auto&&, auto&&) {},
.onDisconnectHook = [](auto&&) {}
}
}; };
HttpAsyncClient client{ctx_}; HttpAsyncClient client{ctx_};
@@ -334,9 +355,13 @@ TEST_F(ServerHttpTest, OnConnectCheckFailed)
ProcessingPolicy::Sequential, ProcessingPolicy::Sequential,
std::nullopt, std::nullopt,
tagDecoratorFactory, tagDecoratorFactory,
web::ProxyIpResolver{{}, {}},
std::nullopt, std::nullopt,
onConnectCheck.AsStdFunction(), Server::Hooks{
[](auto&&) {} .onConnectCheck = onConnectCheck.AsStdFunction(),
.onIpChangeHook = [](auto&&, auto&&) {},
.onDisconnectHook = [](auto&&) {}
}
}; };
HttpAsyncClient client{ctx_}; HttpAsyncClient client{ctx_};
@@ -393,9 +418,13 @@ TEST_F(ServerHttpTest, OnDisconnectHook)
ProcessingPolicy::Sequential, ProcessingPolicy::Sequential,
std::nullopt, std::nullopt,
tagDecoratorFactory, tagDecoratorFactory,
web::ProxyIpResolver{{}, {}},
std::nullopt, std::nullopt,
emptyOnConnectCheck_, Server::Hooks{
onDisconnectHookMock.AsStdFunction() .onConnectCheck = emptyOnConnectCheck_,
.onIpChangeHook = [](auto&&, auto&&) {},
.onDisconnectHook = onDisconnectHookMock.AsStdFunction()
}
}; };
HttpAsyncClient client{ctx_}; HttpAsyncClient client{ctx_};

View File

@@ -25,6 +25,7 @@
#include "util/config/ConfigDefinition.hpp" #include "util/config/ConfigDefinition.hpp"
#include "util/config/ConfigValue.hpp" #include "util/config/ConfigValue.hpp"
#include "util/config/Types.hpp" #include "util/config/Types.hpp"
#include "web/ProxyIpResolver.hpp"
#include "web/SubscriptionContextInterface.hpp" #include "web/SubscriptionContextInterface.hpp"
#include "web/ng/Connection.hpp" #include "web/ng/Connection.hpp"
#include "web/ng/Error.hpp" #include "web/ng/Error.hpp"
@@ -40,11 +41,13 @@
#include <boost/asio/steady_timer.hpp> #include <boost/asio/steady_timer.hpp>
#include <boost/beast/core/flat_buffer.hpp> #include <boost/beast/core/flat_buffer.hpp>
#include <boost/beast/http/error.hpp> #include <boost/beast/http/error.hpp>
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/message.hpp> #include <boost/beast/http/message.hpp>
#include <boost/beast/http/status.hpp> #include <boost/beast/http/status.hpp>
#include <boost/beast/http/string_body.hpp> #include <boost/beast/http/string_body.hpp>
#include <boost/beast/http/verb.hpp> #include <boost/beast/http/verb.hpp>
#include <boost/beast/websocket/error.hpp> #include <boost/beast/websocket/error.hpp>
#include <fmt/format.h>
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
@@ -60,7 +63,6 @@ using namespace web::ng::impl;
using namespace web::ng; using namespace web::ng;
using namespace util; using namespace util;
using testing::Return; using testing::Return;
namespace beast = boost::beast;
namespace http = boost::beast::http; namespace http = boost::beast::http;
namespace websocket = boost::beast::websocket; namespace websocket = boost::beast::websocket;
@@ -69,7 +71,15 @@ struct ConnectionHandlerTest : prometheus::WithPrometheus, SyncAsioContextTest {
: tagFactory{util::config::ClioConfigDefinition{ : tagFactory{util::config::ClioConfigDefinition{
{"log.tag_style", config::ConfigValue{config::ConfigType::String}.defaultValue("uint")} {"log.tag_style", config::ConfigValue{config::ConfigType::String}.defaultValue("uint")}
}} }}
, connectionHandler{policy, maxParallelConnections, tagFactory, std::nullopt, onDisconnectMock.AsStdFunction()} , connectionHandler{
policy,
maxParallelConnections,
tagFactory,
std::nullopt,
proxyIpResolver,
onDisconnectMock.AsStdFunction(),
onIpChangeMock.AsStdFunction()
}
{ {
} }
@@ -98,6 +108,12 @@ struct ConnectionHandlerTest : prometheus::WithPrometheus, SyncAsioContextTest {
return Request{std::forward<Args>(args)...}; return Request{std::forward<Args>(args)...};
} }
std::string const clientIp = "1.2.3.4";
std::string const proxyIp = "5.6.7.8";
std::string const proxyToken = "some_proxy_token";
web::ProxyIpResolver proxyIpResolver{{proxyIp}, {proxyToken}};
testing::StrictMock<testing::MockFunction<void(std::string const&, std::string const&)>> onIpChangeMock;
testing::StrictMock<testing::MockFunction<void(Connection const&)>> onDisconnectMock; testing::StrictMock<testing::MockFunction<void(Connection const&)>> onDisconnectMock;
util::TagDecoratorFactory tagFactory; util::TagDecoratorFactory tagFactory;
ConnectionHandler connectionHandler; ConnectionHandler connectionHandler;
@@ -106,9 +122,9 @@ struct ConnectionHandlerTest : prometheus::WithPrometheus, SyncAsioContextTest {
{"log.tag_style", config::ConfigValue{config::ConfigType::String}.defaultValue("uint")} {"log.tag_style", config::ConfigValue{config::ConfigType::String}.defaultValue("uint")}
}}; }};
StrictMockHttpConnectionPtr mockHttpConnection = StrictMockHttpConnectionPtr mockHttpConnection =
std::make_unique<StrictMockHttpConnection>("1.2.3.4", beast::flat_buffer{}, tagDecoratorFactory); std::make_unique<StrictMockHttpConnection>(clientIp, boost::beast::flat_buffer{}, tagDecoratorFactory);
StrictMockWsConnectionPtr mockWsConnection = StrictMockWsConnectionPtr mockWsConnection =
std::make_unique<StrictMockWsConnection>("1.2.3.4", beast::flat_buffer{}, tagDecoratorFactory); std::make_unique<StrictMockWsConnection>(clientIp, boost::beast::flat_buffer{}, tagDecoratorFactory);
Request::HttpHeaders headers; Request::HttpHeaders headers;
}; };
@@ -452,6 +468,50 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_SendError)
}); });
} }
TEST_F(ConnectionHandlerSequentialProcessingTest, OnIpChangeHookCalledWhenSentFromProxy)
{
std::string const target = "/some/target";
testing::StrictMock<testing::MockFunction<
Response(Request const&, ConnectionMetadata const&, web::SubscriptionContextPtr, boost::asio::yield_context)>>
getHandlerMock;
std::string const requestMessage = "some message";
std::string const responseMessage = "some response";
connectionHandler.onGet(target, getHandlerMock.AsStdFunction());
StrictMockHttpConnectionPtr mockHttpConnectionFromProxy =
std::make_unique<StrictMockHttpConnection>(proxyIp, boost::beast::flat_buffer{}, tagDecoratorFactory);
auto request = http::request<http::string_body>{http::verb::get, target, 11, requestMessage};
request.set(http::field::forwarded, fmt::format("for={}", clientIp));
EXPECT_CALL(*mockHttpConnectionFromProxy, wasUpgraded).WillOnce(Return(false));
EXPECT_CALL(*mockHttpConnectionFromProxy, receive).WillOnce(Return(makeRequest(request)));
EXPECT_CALL(onIpChangeMock, Call(proxyIp, clientIp));
EXPECT_CALL(getHandlerMock, Call).WillOnce([&](Request const& request, auto&&, auto&&, auto&&) {
EXPECT_EQ(request.message(), requestMessage);
return Response(http::status::ok, responseMessage, request);
});
EXPECT_CALL(*mockHttpConnectionFromProxy, send).WillOnce([&responseMessage](Response response, auto&&) {
EXPECT_EQ(response.message(), responseMessage);
return makeError(http::error::end_of_stream).error();
});
EXPECT_CALL(onDisconnectMock, Call)
.WillOnce([this, connectionPtr = mockHttpConnectionFromProxy.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
EXPECT_EQ(c.ip(), clientIp);
});
runSpawn([this, c = std::move(mockHttpConnectionFromProxy)](boost::asio::yield_context yield) mutable {
connectionHandler.processConnection(std::move(c), yield);
});
}
TEST_F(ConnectionHandlerSequentialProcessingTest, Stop) TEST_F(ConnectionHandlerSequentialProcessingTest, Stop)
{ {
testing::StrictMock<testing::MockFunction< testing::StrictMock<testing::MockFunction<
@@ -613,6 +673,48 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send)
}); });
} }
TEST_F(ConnectionHandlerParallelProcessingTest, OnIpChangeHookCalledWhenSentFromProxy)
{
testing::StrictMock<testing::MockFunction<
Response(Request const&, ConnectionMetadata const&, web::SubscriptionContextPtr, boost::asio::yield_context)>>
wsHandlerMock;
connectionHandler.onWs(wsHandlerMock.AsStdFunction());
StrictMockWsConnectionPtr mockWsConnectionFromProxy =
std::make_unique<StrictMockWsConnection>(proxyIp, boost::beast::flat_buffer{}, tagDecoratorFactory);
headers.set(http::field::forwarded, fmt::format("for={}", clientIp));
std::string const requestMessage = "some message";
std::string const responseMessage = "some response";
EXPECT_CALL(*mockWsConnectionFromProxy, wasUpgraded).WillOnce(Return(true));
EXPECT_CALL(*mockWsConnectionFromProxy, receive)
.WillOnce(Return(makeRequest(requestMessage, headers)))
.WillOnce(Return(makeError(websocket::error::closed)));
EXPECT_CALL(onIpChangeMock, Call(proxyIp, clientIp));
EXPECT_CALL(wsHandlerMock, Call).WillOnce([&](Request const& request, auto&&, auto&&, auto&&) {
EXPECT_EQ(request.message(), requestMessage);
return Response(http::status::ok, responseMessage, request);
});
EXPECT_CALL(*mockWsConnectionFromProxy, send).WillOnce([&responseMessage](Response response, auto&&) {
EXPECT_EQ(response.message(), responseMessage);
return std::nullopt;
});
EXPECT_CALL(onDisconnectMock, Call)
.WillOnce([this, connectionPtr = mockWsConnectionFromProxy.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
EXPECT_EQ(c.ip(), clientIp);
});
runSpawn([this, c = std::move(mockWsConnectionFromProxy)](boost::asio::yield_context yield) mutable {
connectionHandler.processConnection(std::move(c), yield);
});
}
TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop) TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop)
{ {
testing::StrictMock<testing::MockFunction< testing::StrictMock<testing::MockFunction<

View File

@@ -38,6 +38,7 @@
#include <boost/beast/http/status.hpp> #include <boost/beast/http/status.hpp>
#include <boost/beast/http/string_body.hpp> #include <boost/beast/http/string_body.hpp>
#include <boost/beast/http/verb.hpp> #include <boost/beast/http/verb.hpp>
#include <fmt/format.h>
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>

View File

@@ -39,7 +39,9 @@
#include <boost/asio/spawn.hpp> #include <boost/asio/spawn.hpp>
#include <boost/asio/steady_timer.hpp> #include <boost/asio/steady_timer.hpp>
#include <boost/beast/core/flat_buffer.hpp> #include <boost/beast/core/flat_buffer.hpp>
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/status.hpp> #include <boost/beast/http/status.hpp>
#include <fmt/format.h>
#include <gmock/gmock.h> #include <gmock/gmock.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
@@ -65,7 +67,10 @@ struct WebWsConnectionTests : SyncAsioContextTest {
auto ip = expectedSocket->remote_endpoint().address().to_string(); auto ip = expectedSocket->remote_endpoint().address().to_string();
PlainHttpConnection httpConnection{ PlainHttpConnection httpConnection{
std::move(expectedSocket).value(), std::move(ip), boost::beast::flat_buffer{}, tagDecoratorFactory_ std::move(expectedSocket).value(),
std::move(ip),
boost::beast::flat_buffer{},
tagDecoratorFactory_,
}; };
auto expectedTrue = httpConnection.isUpgradeRequested(yield); auto expectedTrue = httpConnection.isUpgradeRequested(yield);