feat: Add dosguard to new webserver (#1772)

For #919.
This commit is contained in:
Sergey Kuznetsov
2024-12-10 14:59:12 +00:00
committed by GitHub
parent 4ff2953257
commit 7bef13f913
23 changed files with 1325 additions and 148 deletions

View File

@@ -1,4 +1,4 @@
add_library(clio_app)
target_sources(clio_app PRIVATE CliArgs.cpp ClioApplication.cpp)
target_sources(clio_app PRIVATE CliArgs.cpp ClioApplication.cpp WebHandlers.cpp)
target_link_libraries(clio_app PUBLIC clio_etl clio_etlng clio_feed clio_web clio_rpc)

View File

@@ -19,6 +19,7 @@
#include "app/ClioApplication.hpp"
#include "app/WebHandlers.hpp"
#include "data/AmendmentCenter.hpp"
#include "data/BackendFactory.hpp"
#include "etl/ETLService.hpp"
@@ -30,11 +31,9 @@
#include "rpc/RPCEngine.hpp"
#include "rpc/WorkQueue.hpp"
#include "rpc/common/impl/HandlerProvider.hpp"
#include "util/Assert.hpp"
#include "util/build/Build.hpp"
#include "util/config/Config.hpp"
#include "util/log/Logger.hpp"
#include "util/prometheus/Http.hpp"
#include "util/prometheus/Prometheus.hpp"
#include "web/AdminVerificationStrategy.hpp"
#include "web/RPCServerHandler.hpp"
@@ -52,11 +51,14 @@
#include <boost/asio/io_context.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/beast/http/status.hpp>
#include <boost/json/array.hpp>
#include <boost/json/parse.hpp>
#include <cstdint>
#include <cstdlib>
#include <exception>
#include <memory>
#include <optional>
#include <thread>
#include <utility>
#include <vector>
@@ -65,14 +67,6 @@ namespace app {
namespace {
auto constexpr HealthCheckHTML = R"html(
<!DOCTYPE html>
<html>
<head><title>Test page for Clio</title></head>
<body><h1>Clio Test</h1><p>This page shows Clio http(s) connectivity is working.</p></body>
</html>
)html";
/**
* @brief Start context threads
*
@@ -158,72 +152,18 @@ ClioApplication::run(bool const useNgWebServer)
}
auto const adminVerifier = std::move(expectedAdminVerifier).value();
auto httpServer = web::ng::make_Server(config_, ioc);
auto httpServer = web::ng::make_Server(config_, OnConnectCheck{dosGuard}, DisconnectHook{dosGuard}, ioc);
if (not httpServer.has_value()) {
LOG(util::LogService::error()) << "Error creating web server: " << httpServer.error();
return EXIT_FAILURE;
}
httpServer->onGet(
"/metrics",
[adminVerifier](
web::ng::Request const& request,
web::ng::ConnectionMetadata& connectionMetadata,
web::SubscriptionContextPtr,
boost::asio::yield_context
) -> web::ng::Response {
auto const maybeHttpRequest = request.asHttpRequest();
ASSERT(maybeHttpRequest.has_value(), "Got not a http request in Get");
auto const& httpRequest = maybeHttpRequest->get();
// FIXME(#1702): Using veb server thread to handle prometheus request. Better to post on work queue.
auto maybeResponse = util::prometheus::handlePrometheusRequest(
httpRequest, adminVerifier->isAdmin(httpRequest, connectionMetadata.ip())
);
ASSERT(maybeResponse.has_value(), "Got unexpected request for Prometheus");
return web::ng::Response{std::move(maybeResponse).value(), request};
}
);
httpServer->onGet(
"/health",
[](web::ng::Request const& request,
web::ng::ConnectionMetadata&,
web::SubscriptionContextPtr,
boost::asio::yield_context) -> web::ng::Response {
return web::ng::Response{boost::beast::http::status::ok, HealthCheckHTML, request};
}
);
util::Logger webServerLog{"WebServer"};
auto onRequest = [adminVerifier, &webServerLog, &handler](
web::ng::Request const& request,
web::ng::ConnectionMetadata& connectionMetadata,
web::SubscriptionContextPtr subscriptionContext,
boost::asio::yield_context yield
) -> web::ng::Response {
LOG(webServerLog.info()) << connectionMetadata.tag()
<< "Received request from ip = " << connectionMetadata.ip()
<< " - posting to WorkQueue";
connectionMetadata.setIsAdmin([&adminVerifier, &request, &connectionMetadata]() {
return adminVerifier->isAdmin(request.httpHeaders(), connectionMetadata.ip());
});
try {
return handler(request, connectionMetadata, std::move(subscriptionContext), yield);
} catch (std::exception const&) {
return web::ng::Response{
boost::beast::http::status::internal_server_error,
rpc::makeError(rpc::RippledError::rpcINTERNAL),
request
};
}
};
httpServer->onPost("/", onRequest);
httpServer->onWs(onRequest);
httpServer->onGet("/metrics", MetricsHandler{adminVerifier});
httpServer->onGet("/health", HealthCheckHandler{});
auto requestHandler = RequestHandler{adminVerifier, handler, dosGuard};
httpServer->onPost("/", requestHandler);
httpServer->onWs(std::move(requestHandler));
auto const maybeError = httpServer->run();
if (maybeError.has_value()) {

111
src/app/WebHandlers.cpp Normal file
View File

@@ -0,0 +1,111 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "app/WebHandlers.hpp"
#include "util/Assert.hpp"
#include "util/prometheus/Http.hpp"
#include "web/AdminVerificationStrategy.hpp"
#include "web/SubscriptionContextInterface.hpp"
#include "web/dosguard/DOSGuardInterface.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/Request.hpp"
#include "web/ng/Response.hpp"
#include <boost/asio/spawn.hpp>
#include <boost/beast/http/status.hpp>
#include <memory>
#include <optional>
#include <utility>
namespace app {
OnConnectCheck::OnConnectCheck(web::dosguard::DOSGuardInterface& dosguard) : dosguard_{dosguard}
{
}
std::expected<void, web::ng::Response>
OnConnectCheck::operator()(web::ng::Connection const& connection)
{
dosguard_.get().increment(connection.ip());
if (not dosguard_.get().isOk(connection.ip())) {
return std::unexpected{
web::ng::Response{boost::beast::http::status::too_many_requests, "Too many requests", connection}
};
}
return {};
}
DisconnectHook::DisconnectHook(web::dosguard::DOSGuardInterface& dosguard) : dosguard_{dosguard}
{
}
void
DisconnectHook::operator()(web::ng::Connection const& connection)
{
dosguard_.get().decrement(connection.ip());
}
MetricsHandler::MetricsHandler(std::shared_ptr<web::AdminVerificationStrategy> adminVerifier)
: adminVerifier_{std::move(adminVerifier)}
{
}
web::ng::Response
MetricsHandler::operator()(
web::ng::Request const& request,
web::ng::ConnectionMetadata& connectionMetadata,
web::SubscriptionContextPtr,
boost::asio::yield_context
)
{
auto const maybeHttpRequest = request.asHttpRequest();
ASSERT(maybeHttpRequest.has_value(), "Got not a http request in Get");
auto const& httpRequest = maybeHttpRequest->get();
// FIXME(#1702): Using veb server thread to handle prometheus request. Better to post on work queue.
auto maybeResponse = util::prometheus::handlePrometheusRequest(
httpRequest, adminVerifier_->isAdmin(httpRequest, connectionMetadata.ip())
);
ASSERT(maybeResponse.has_value(), "Got unexpected request for Prometheus");
return web::ng::Response{std::move(maybeResponse).value(), request};
}
web::ng::Response
HealthCheckHandler::operator()(
web::ng::Request const& request,
web::ng::ConnectionMetadata&,
web::SubscriptionContextPtr,
boost::asio::yield_context
)
{
static auto constexpr HealthCheckHTML = R"html(
<!DOCTYPE html>
<html>
<head><title>Test page for Clio</title></head>
<body><h1>Clio Test</h1><p>This page shows Clio http(s) connectivity is working.</p></body>
</html>
)html";
return web::ng::Response{boost::beast::http::status::ok, HealthCheckHTML, request};
}
} // namespace app

234
src/app/WebHandlers.hpp Normal file
View File

@@ -0,0 +1,234 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include "rpc/Errors.hpp"
#include "util/log/Logger.hpp"
#include "web/AdminVerificationStrategy.hpp"
#include "web/SubscriptionContextInterface.hpp"
#include "web/dosguard/DOSGuardInterface.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/Request.hpp"
#include "web/ng/Response.hpp"
#include <boost/asio/spawn.hpp>
#include <boost/beast/http/status.hpp>
#include <boost/json/array.hpp>
#include <boost/json/parse.hpp>
#include <exception>
#include <functional>
#include <memory>
#include <utility>
namespace app {
/**
* @brief A function object that checks if the connection is allowed to proceed.
*/
class OnConnectCheck {
std::reference_wrapper<web::dosguard::DOSGuardInterface> dosguard_;
public:
/**
* @brief Construct a new OnConnectCheck object
*
* @param dosguard The DOSGuardInterface to use for checking the connection.
*/
OnConnectCheck(web::dosguard::DOSGuardInterface& dosguard);
/**
* @brief Check if the connection is allowed to proceed.
*
* @param connection The connection to check.
* @return A response if the connection is not allowed to proceed or void otherwise.
*/
std::expected<void, web::ng::Response>
operator()(web::ng::Connection const& connection);
};
/**
* @brief A function object to be called when a connection is disconnected.
*/
class DisconnectHook {
std::reference_wrapper<web::dosguard::DOSGuardInterface> dosguard_;
public:
/**
* @brief Construct a new DisconnectHook object
*
* @param dosguard The DOSGuardInterface to use for disconnecting the connection.
*/
DisconnectHook(web::dosguard::DOSGuardInterface& dosguard);
/**
* @brief The call of the function object.
*
* @param connection The connection which has disconnected.
*/
void
operator()(web::ng::Connection const& connection);
};
/**
* @brief A function object that handles the metrics endpoint.
*/
class MetricsHandler {
std::shared_ptr<web::AdminVerificationStrategy> adminVerifier_;
public:
/**
* @brief Construct a new MetricsHandler object
*
* @param adminVerifier The AdminVerificationStrategy to use for verifying the connection for admin access.
*/
MetricsHandler(std::shared_ptr<web::AdminVerificationStrategy> adminVerifier);
/**
* @brief The call of the function object.
*
* @param request The request to handle.
* @param connectionMetadata The connection metadata.
* @return The response to the request.
*/
web::ng::Response
operator()(
web::ng::Request const& request,
web::ng::ConnectionMetadata& connectionMetadata,
web::SubscriptionContextPtr,
boost::asio::yield_context
);
};
/**
* @brief A function object that handles the health check endpoint.
*/
class HealthCheckHandler {
public:
/**
* @brief The call of the function object.
*
* @param request The request to handle.
* @return The response to the request
*/
web::ng::Response
operator()(
web::ng::Request const& request,
web::ng::ConnectionMetadata&,
web::SubscriptionContextPtr,
boost::asio::yield_context
);
};
/**
* @brief A function object that handles the websocket endpoint.
*
* @tparam RpcHandlerType The type of the RPC handler.
*/
template <typename RpcHandlerType>
class RequestHandler {
util::Logger webServerLog_{"WebServer"};
std::shared_ptr<web::AdminVerificationStrategy> adminVerifier_;
std::reference_wrapper<RpcHandlerType> rpcHandler_;
std::reference_wrapper<web::dosguard::DOSGuardInterface> dosguard_;
public:
/**
* @brief Construct a new RequestHandler object
*
* @param adminVerifier The AdminVerificationStrategy to use for verifying the connection for admin access.
* @param rpcHandler The RPC handler to use for handling the request.
* @param dosguard The DOSGuardInterface to use for checking the connection.
*/
RequestHandler(
std::shared_ptr<web::AdminVerificationStrategy> adminVerifier,
RpcHandlerType& rpcHandler,
web::dosguard::DOSGuardInterface& dosguard
)
: adminVerifier_(std::move(adminVerifier)), rpcHandler_(rpcHandler), dosguard_(dosguard)
{
}
/**
* @brief The call of the function object.
*
* @param request The request to handle.
* @param connectionMetadata The connection metadata.
* @param subscriptionContext The subscription context.
* @param yield The yield context.
* @return The response to the request.
*/
web::ng::Response
operator()(
web::ng::Request const& request,
web::ng::ConnectionMetadata& connectionMetadata,
web::SubscriptionContextPtr subscriptionContext,
boost::asio::yield_context yield
)
{
if (not dosguard_.get().request(connectionMetadata.ip())) {
auto error = rpc::makeError(rpc::RippledError::rpcSLOW_DOWN);
if (not request.isHttp()) {
try {
auto requestJson = boost::json::parse(request.message());
if (requestJson.is_object() && requestJson.as_object().contains("id"))
error["id"] = requestJson.as_object().at("id");
error["request"] = request.message();
} catch (std::exception const&) {
error["request"] = request.message();
}
}
return web::ng::Response{boost::beast::http::status::service_unavailable, error, request};
}
LOG(webServerLog_.info()) << connectionMetadata.tag()
<< "Received request from ip = " << connectionMetadata.ip()
<< " - posting to WorkQueue";
connectionMetadata.setIsAdmin([this, &request, &connectionMetadata]() {
return adminVerifier_->isAdmin(request.httpHeaders(), connectionMetadata.ip());
});
try {
auto response = rpcHandler_(request, connectionMetadata, std::move(subscriptionContext), yield);
if (not dosguard_.get().add(connectionMetadata.ip(), response.message().size())) {
auto jsonResponse = boost::json::parse(response.message()).as_object();
jsonResponse["warning"] = "load";
if (jsonResponse.contains("warnings") && jsonResponse["warnings"].is_array()) {
jsonResponse["warnings"].as_array().push_back(rpc::makeWarning(rpc::warnRPC_RATE_LIMIT));
} else {
jsonResponse["warnings"] = boost::json::array{rpc::makeWarning(rpc::warnRPC_RATE_LIMIT)};
}
response.setMessage(jsonResponse);
}
return response;
} catch (std::exception const&) {
return web::ng::Response{
boost::beast::http::status::internal_server_error,
rpc::makeError(rpc::RippledError::rpcINTERNAL),
request
};
}
}
};
} // namespace app

View File

@@ -22,6 +22,7 @@
#include "util/Assert.hpp"
#include "util/OverloadSet.hpp"
#include "util/build/Build.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/Request.hpp"
#include <boost/asio/buffer.hpp>
@@ -33,6 +34,7 @@
#include <boost/json/serialize.hpp>
#include <fmt/core.h>
#include <cstdint>
#include <optional>
#include <string>
#include <type_traits>
@@ -45,42 +47,63 @@ namespace web::ng {
namespace {
template <typename T>
consteval bool
isString()
struct MessageData {
template <typename MessageType>
MessageData(MessageType message)
{
if constexpr (std::is_same_v<MessageType, std::string>) {
body = std::move(message);
contentType = "text/html";
} else {
body = boost::json::serialize(message);
contentType = "application/json";
}
}
std::string body;
std::string contentType;
};
http::response<http::string_body>
prepareResponse(http::response<http::string_body> response, bool keepAlive)
{
return std::is_same_v<T, std::string>;
response.set(http::field::server, fmt::format("clio-server-{}", util::build::getClioVersionString()));
response.keep_alive(keepAlive);
response.prepare_payload();
return response;
}
http::response<http::string_body>
prepareResponse(http::response<http::string_body> response, http::request<http::string_body> const& request)
makeHttpData(MessageData messageData, http::status status, uint16_t httpVersion, bool keepAlive)
{
response.set(http::field::server, fmt::format("clio-server-{}", util::build::getClioVersionString()));
response.keep_alive(request.keep_alive());
response.prepare_payload();
return response;
http::response<http::string_body> result{status, httpVersion, std::move(messageData.body)};
result.set(http::field::content_type, messageData.contentType);
return prepareResponse(std::move(result), keepAlive);
}
template <typename MessageType>
std::variant<http::response<http::string_body>, std::string>
makeData(http::status status, MessageType message, Request const& request)
{
std::string body;
if constexpr (isString<MessageType>()) {
body = std::move(message);
} else {
body = boost::json::serialize(message);
}
MessageData messageData{std::move(message)};
if (not request.isHttp())
return body;
return std::move(messageData).body;
auto const& httpRequest = request.asHttpRequest()->get();
std::string const contentType = isString<MessageType>() ? "text/html" : "application/json";
return makeHttpData(std::move(messageData), status, httpRequest.version(), httpRequest.keep_alive());
}
http::response<http::string_body> result{status, httpRequest.version(), std::move(body)};
result.set(http::field::content_type, contentType);
return prepareResponse(std::move(result), httpRequest);
template <typename MessageType>
std::variant<http::response<http::string_body>, std::string>
makeData(http::status status, MessageType message, Connection const& connection)
{
MessageData messageData{std::move(message)};
if (connection.wasUpgraded())
return std::move(messageData).body;
return makeHttpData(std::move(messageData), status, 11, false);
}
} // namespace
@@ -95,10 +118,20 @@ Response::Response(boost::beast::http::status status, boost::json::object const&
{
}
Response::Response(boost::beast::http::status status, boost::json::object const& message, Connection const& connection)
: data_{makeData(status, message, connection)}
{
}
Response::Response(boost::beast::http::status status, std::string message, Connection const& connection)
: data_{makeData(status, std::move(message), connection)}
{
}
Response::Response(boost::beast::http::response<boost::beast::http::string_body> response, Request const& request)
{
ASSERT(request.isHttp(), "Request must be HTTP to construct response from HTTP response");
data_ = prepareResponse(std::move(response), request.asHttpRequest()->get());
data_ = prepareResponse(std::move(response), request.asHttpRequest()->get().keep_alive());
}
std::string const&
@@ -115,6 +148,34 @@ Response::message() const
);
}
void
Response::setMessage(std::string newMessage)
{
if (std::holds_alternative<std::string>(data_)) {
std::get<std::string>(data_) = std::move(newMessage);
return;
}
MessageData messageData{std::move(newMessage)};
auto const& oldHttpResponse = std::get<http::response<http::string_body>>(data_);
data_ = makeHttpData(
std::move(messageData), oldHttpResponse.result(), oldHttpResponse.version(), oldHttpResponse.keep_alive()
);
}
void
Response::setMessage(boost::json::object const& newMessage)
{
MessageData messageData{newMessage};
if (std::holds_alternative<std::string>(data_)) {
std::get<std::string>(data_) = std::move(messageData).body;
return;
}
auto const& oldHttpResponse = std::get<http::response<http::string_body>>(data_);
data_ = makeHttpData(
std::move(messageData), oldHttpResponse.result(), oldHttpResponse.version(), oldHttpResponse.keep_alive()
);
}
http::response<http::string_body>
Response::intoHttpResponse() &&
{

View File

@@ -32,6 +32,8 @@
namespace web::ng {
class Connection;
/**
* @brief Represents an HTTP or Websocket response.
*/
@@ -60,6 +62,26 @@ public:
*/
Response(boost::beast::http::status status, boost::json::object const& message, Request const& request);
/**
* @brief Construct a Response from string. Content type will be text/html.
*
* @param status The HTTP status.
* @param message The message to send.
* @param connection The connection that triggered this response. Used to determine whether the response should
* contain HTTP or WebSocket data.
*/
Response(boost::beast::http::status status, boost::json::object const& message, Connection const& connection);
/**
* @brief Construct a Response from string. Content type will be text/html.
*
* @param status The HTTP status.
* @param message The message to send.
* @param connection The connection that triggered this response. Used to determine whether the response should
* contain HTTP or WebSocket data.
*/
Response(boost::beast::http::status status, std::string message, Connection const& connection);
/**
* @brief Construct a Response from HTTP response.
*
@@ -76,6 +98,22 @@ public:
std::string const&
message() const;
/**
* @brief Replace existing message (or body) with new message.
*
* @param newMessage The new message.
*/
void
setMessage(std::string newMessage);
/**
* @brief Replace existing message (or body) with new message.
*
* @param newMessage The new message.
*/
void
setMessage(boost::json::object const& newMessage);
/**
* @brief Convert the Response to an HTTP response.
* @note The Response must be constructed with an HTTP request.

View File

@@ -124,12 +124,13 @@ detectSsl(boost::asio::ip::tcp::socket socket, boost::asio::yield_context yield)
return SslDetectionResult{.socket = tcpStream.release_socket(), .isSsl = isSsl, .buffer = std::move(buffer)};
}
std::expected<ConnectionPtr, std::string>
std::expected<ConnectionPtr, std::optional<std::string>>
makeConnection(
SslDetectionResult sslDetectionResult,
std::optional<boost::asio::ssl::context>& sslContext,
std::string ip,
util::TagDecoratorFactory& tagDecoratorFactory,
Server::OnConnectCheck onConnectCheck,
boost::asio::yield_context yield
)
{
@@ -154,6 +155,13 @@ makeConnection(
);
}
auto expectedSuccess = onConnectCheck(*connection);
if (not expectedSuccess.has_value()) {
connection->send(std::move(expectedSuccess).error(), yield);
connection->close(yield);
return std::unexpected{std::nullopt};
}
auto const expectedIsUpgrade = connection->isUpgradeRequested(yield);
if (not expectedIsUpgrade.has_value()) {
return std::unexpected{
@@ -182,13 +190,16 @@ Server::Server(
ProcessingPolicy processingPolicy,
std::optional<size_t> parallelRequestLimit,
util::TagDecoratorFactory tagDecoratorFactory,
std::optional<size_t> maxSubscriptionSendQueueSize
std::optional<size_t> maxSubscriptionSendQueueSize,
OnConnectCheck onConnectCheck,
OnDisconnectHook onDisconnectHook
)
: ctx_{ctx}
, sslContext_{std::move(sslContext)}
, tagDecoratorFactory_{tagDecoratorFactory}
, connectionHandler_{processingPolicy, parallelRequestLimit, tagDecoratorFactory_, maxSubscriptionSendQueueSize}
, connectionHandler_{processingPolicy, parallelRequestLimit, tagDecoratorFactory_, maxSubscriptionSendQueueSize, std::move(onDisconnectHook)}
, endpoint_{std::move(endpoint)}
, onConnectCheck_{std::move(onConnectCheck)}
{
}
@@ -269,13 +280,18 @@ Server::handleConnection(boost::asio::ip::tcp::socket socket, boost::asio::yield
return;
}
// TODO(kuznetsss): check ip with dosguard here
auto connectionExpected = makeConnection(
std::move(sslDetectionResult).value(), sslContext_, std::move(ip).value(), tagDecoratorFactory_, yield
std::move(sslDetectionResult).value(),
sslContext_,
std::move(ip).value(),
tagDecoratorFactory_,
onConnectCheck_,
yield
);
if (not connectionExpected.has_value()) {
LOG(log_.info()) << "Error creating a connection: " << connectionExpected.error();
if (connectionExpected.error().has_value()) {
LOG(log_.info()) << "Error creating a connection: " << *connectionExpected.error();
}
return;
}
@@ -288,7 +304,12 @@ Server::handleConnection(boost::asio::ip::tcp::socket socket, boost::asio::yield
}
std::expected<Server, std::string>
make_Server(util::Config const& config, boost::asio::io_context& context)
make_Server(
util::Config const& config,
Server::OnConnectCheck onConnectCheck,
Server::OnDisconnectHook onDisconnectHook,
boost::asio::io_context& context
)
{
auto const serverConfig = config.section("server");
@@ -321,7 +342,9 @@ make_Server(util::Config const& config, boost::asio::io_context& context)
processingPolicy,
parallelRequestLimit,
util::TagDecoratorFactory(config),
maxSubscriptionSendQueueSize
maxSubscriptionSendQueueSize,
std::move(onConnectCheck),
std::move(onDisconnectHook)
};
}

View File

@@ -22,8 +22,10 @@
#include "util/Taggable.hpp"
#include "util/config/Config.hpp"
#include "util/log/Logger.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/MessageHandler.hpp"
#include "web/ng/ProcessingPolicy.hpp"
#include "web/ng/Response.hpp"
#include "web/ng/impl/ConnectionHandler.hpp"
#include <boost/asio/io_context.hpp>
@@ -42,6 +44,19 @@ namespace web::ng {
* @brief Web server class.
*/
class Server {
public:
/**
* @brief Check to perform for each new client connection. The check takes client ip as input and returns a Response
* if the check failed. Response will be sent to the client and the connection will be closed.
*/
using OnConnectCheck = std::function<std::expected<void, Response>(Connection const&)>;
/**
* @brief Hook called when any connection disconnects
*/
using OnDisconnectHook = impl::ConnectionHandler::OnDisconnectHook;
private:
util::Logger log_{"WebServer"};
util::Logger perfLog_{"Performance"};
@@ -53,6 +68,8 @@ class Server {
impl::ConnectionHandler connectionHandler_;
boost::asio::ip::tcp::endpoint endpoint_;
OnConnectCheck onConnectCheck_;
bool running_{false};
public:
@@ -67,6 +84,8 @@ public:
* if processingPolicy is parallel.
* @param tagDecoratorFactory The tag decorator factory.
* @param maxSubscriptionSendQueueSize The maximum size of the subscription send queue.
* @param onConnectCheck The check to perform on each connection.
* @param onDisconnectHook The hook to call on each disconnection.
*/
Server(
boost::asio::io_context& ctx,
@@ -75,7 +94,9 @@ public:
ProcessingPolicy processingPolicy,
std::optional<size_t> parallelRequestLimit,
util::TagDecoratorFactory tagDecoratorFactory,
std::optional<size_t> maxSubscriptionSendQueueSize
std::optional<size_t> maxSubscriptionSendQueueSize,
OnConnectCheck onConnectCheck,
OnDisconnectHook onDisconnectHook
);
/**
@@ -141,11 +162,18 @@ private:
* @brief Create a new Server.
*
* @param config The configuration.
* @param onConnectCheck The check to perform on each client connection.
* @param onDisconnectHook The hook to call when client disconnects.
* @param context The boost::asio::io_context to use.
*
* @return The Server or an error message.
*/
std::expected<Server, std::string>
make_Server(util::Config const& config, boost::asio::io_context& context);
make_Server(
util::Config const& config,
Server::OnConnectCheck onConnectCheck,
Server::OnDisconnectHook onDisconnectHook,
boost::asio::io_context& context
);
} // namespace web::ng

View File

@@ -106,12 +106,14 @@ ConnectionHandler::ConnectionHandler(
ProcessingPolicy processingPolicy,
std::optional<size_t> maxParallelRequests,
util::TagDecoratorFactory& tagFactory,
std::optional<size_t> maxSubscriptionSendQueueSize
std::optional<size_t> maxSubscriptionSendQueueSize,
OnDisconnectHook onDisconnectHook
)
: processingPolicy_{processingPolicy}
, maxParallelRequests_{maxParallelRequests}
, tagFactory_{tagFactory}
, maxSubscriptionSendQueueSize_{maxSubscriptionSendQueueSize}
, onDisconnectHook_{std::move(onDisconnectHook)}
{
}
@@ -171,6 +173,7 @@ ConnectionHandler::processConnection(ConnectionPtr connectionPtr, boost::asio::y
connectionRef.close(yield);
signalConnection.disconnect();
onDisconnectHook_(connectionRef);
}
void

View File

@@ -44,6 +44,8 @@ namespace web::ng::impl {
class ConnectionHandler {
public:
using OnDisconnectHook = std::function<void(Connection const&)>;
struct StringHash {
using hash_type = std::hash<std::string_view>;
using is_transparent = void;
@@ -68,6 +70,8 @@ private:
std::reference_wrapper<util::TagDecoratorFactory> tagFactory_;
std::optional<size_t> maxSubscriptionSendQueueSize_;
OnDisconnectHook onDisconnectHook_;
TargetToHandlerMap getHandlers_;
TargetToHandlerMap postHandlers_;
std::optional<MessageHandler> wsHandler_;
@@ -79,7 +83,8 @@ public:
ProcessingPolicy processingPolicy,
std::optional<size_t> maxParallelRequests,
util::TagDecoratorFactory& tagFactory,
std::optional<size_t> maxSubscriptionSendQueueSize
std::optional<size_t> maxSubscriptionSendQueueSize,
OnDisconnectHook onDisconnectHook
);
void

View File

@@ -65,6 +65,13 @@ public:
util::TagDecoratorFactory const& tagDecoratorFactory,
boost::asio::yield_context yield
) = 0;
virtual std::optional<Error>
sendRaw(
boost::beast::http::response<boost::beast::http::string_body> response,
boost::asio::yield_context yield,
std::chrono::steady_clock::duration timeout = DEFAULT_TIMEOUT
) = 0;
};
using UpgradableConnectionPtr = std::unique_ptr<UpgradableConnection>;
@@ -105,6 +112,21 @@ public:
return false;
}
std::optional<Error>
sendRaw(
boost::beast::http::response<boost::beast::http::string_body> response,
boost::asio::yield_context yield,
std::chrono::steady_clock::duration timeout = DEFAULT_TIMEOUT
) override
{
boost::system::error_code error;
boost::beast::get_lowest_layer(stream_).expires_after(timeout);
boost::beast::http::async_write(stream_, response, yield[error]);
if (error)
return error;
return std::nullopt;
}
std::optional<Error>
send(
Response response,
@@ -112,13 +134,8 @@ public:
std::chrono::steady_clock::duration timeout = DEFAULT_TIMEOUT
) override
{
auto const httpResponse = std::move(response).intoHttpResponse();
boost::system::error_code error;
boost::beast::get_lowest_layer(stream_).expires_after(timeout);
boost::beast::http::async_write(stream_, httpResponse, yield[error]);
if (error)
return error;
return std::nullopt;
auto httpResponse = std::move(response).intoHttpResponse();
return sendRaw(std::move(httpResponse), yield, timeout);
}
std::expected<Request, Error>

View File

@@ -0,0 +1,41 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#pragma once
#include "web/dosguard/DOSGuardInterface.hpp"
#include <gmock/gmock.h>
#include <cstdint>
#include <string>
#include <string_view>
struct DOSGuardMockImpl : web::dosguard::DOSGuardInterface {
MOCK_METHOD(bool, isWhiteListed, (std::string_view const ip), (const, noexcept, override));
MOCK_METHOD(bool, isOk, (std::string const& ip), (const, noexcept, override));
MOCK_METHOD(void, increment, (std::string const& ip), (noexcept, override));
MOCK_METHOD(void, decrement, (std::string const& ip), (noexcept, override));
MOCK_METHOD(bool, add, (std::string const& ip, uint32_t size), (noexcept, override));
MOCK_METHOD(bool, request, (std::string const& ip), (noexcept, override));
MOCK_METHOD(void, clear, (), (noexcept, override));
};
using DOSGuardMock = testing::NiceMock<DOSGuardMockImpl>;
using DOSGuardStrictMock = testing::StrictMock<DOSGuardMockImpl>;

View File

@@ -19,7 +19,6 @@
#pragma once
#include "util/Taggable.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/Error.hpp"
#include "web/ng/Request.hpp"

View File

@@ -29,6 +29,8 @@
#include <boost/asio/spawn.hpp>
#include <boost/asio/ssl/context.hpp>
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/string_body.hpp>
#include <gmock/gmock.h>
#include <chrono>
@@ -48,6 +50,15 @@ struct MockHttpConnectionImpl : web::ng::impl::UpgradableConnection {
(override)
);
MOCK_METHOD(
SendReturnType,
sendRaw,
(boost::beast::http::response<boost::beast::http::string_body>,
boost::asio::yield_context,
std::chrono::steady_clock::duration),
(override)
);
using ReceiveReturnType = std::expected<web::ng::Request, web::ng::Error>;
MOCK_METHOD(
ReceiveReturnType,

View File

@@ -5,6 +5,7 @@ target_sources(
PRIVATE # Common
ConfigTests.cpp
app/CliArgsTests.cpp
app/WebHandlersTests.cpp
data/AmendmentCenterTests.cpp
data/BackendCountersTests.cpp
data/BackendInterfaceTests.cpp

View File

@@ -0,0 +1,321 @@
//------------------------------------------------------------------------------
/*
This file is part of clio: https://github.com/XRPLF/clio
Copyright (c) 2024, the clio developers.
Permission to use, copy, modify, and distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#include "app/WebHandlers.hpp"
#include "util/AsioContextTestFixture.hpp"
#include "util/LoggerFixtures.hpp"
#include "util/MockPrometheus.hpp"
#include "util/Taggable.hpp"
#include "util/config/Config.hpp"
#include "web/AdminVerificationStrategy.hpp"
#include "web/SubscriptionContextInterface.hpp"
#include "web/dosguard/DOSGuardMock.hpp"
#include "web/ng/Connection.hpp"
#include "web/ng/MockConnection.hpp"
#include "web/ng/Request.hpp"
#include "web/ng/Response.hpp"
#include <boost/asio/spawn.hpp>
#include <boost/beast/core/flat_buffer.hpp>
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/status.hpp>
#include <boost/beast/http/string_body.hpp>
#include <boost/beast/http/verb.hpp>
#include <boost/json/parse.hpp>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <utility>
using namespace app;
namespace http = boost::beast::http;
struct WebHandlersTest : virtual NoLoggerFixture {
DOSGuardStrictMock dosGuardMock_;
util::TagDecoratorFactory tagFactory_{util::Config{}};
std::string const ip_ = "some ip";
StrictMockConnection connectionMock_{ip_, boost::beast::flat_buffer{}, tagFactory_};
struct AdminVerificationStrategyMock : web::AdminVerificationStrategy {
MOCK_METHOD(bool, isAdmin, (RequestHeader const&, std::string_view), (const, override));
};
using AdminVerificationStrategyStrictMockPtr = std::shared_ptr<testing::StrictMock<AdminVerificationStrategyMock>>;
};
struct OnConnectCheckTests : WebHandlersTest {
OnConnectCheck onConnectCheck_{dosGuardMock_};
};
TEST_F(OnConnectCheckTests, Ok)
{
EXPECT_CALL(dosGuardMock_, increment(ip_));
EXPECT_CALL(dosGuardMock_, isOk(ip_)).WillOnce(testing::Return(true));
EXPECT_TRUE(onConnectCheck_(connectionMock_).has_value());
}
TEST_F(OnConnectCheckTests, RateLimited)
{
EXPECT_CALL(dosGuardMock_, increment(ip_));
EXPECT_CALL(dosGuardMock_, isOk(ip_)).WillOnce(testing::Return(false));
EXPECT_CALL(connectionMock_, wasUpgraded).WillOnce(testing::Return(false));
auto response = onConnectCheck_(connectionMock_);
ASSERT_FALSE(response.has_value());
auto const httpResponse = std::move(response).error().intoHttpResponse();
EXPECT_EQ(httpResponse.result(), boost::beast::http::status::too_many_requests);
EXPECT_EQ(httpResponse.body(), "Too many requests");
}
struct DisconnectHookTests : WebHandlersTest {
DisconnectHook disconnectHook_{dosGuardMock_};
};
TEST_F(DisconnectHookTests, CallsDecrement)
{
EXPECT_CALL(dosGuardMock_, decrement(ip_));
disconnectHook_(connectionMock_);
}
struct MetricsHandlerTests : util::prometheus::WithPrometheus, SyncAsioContextTest, WebHandlersTest {
AdminVerificationStrategyStrictMockPtr adminVerifier_{
std::make_shared<testing::StrictMock<AdminVerificationStrategyMock>>()
};
MetricsHandler metricsHandler_{adminVerifier_};
web::ng::Request request_{http::request<http::string_body>{http::verb::get, "/metrics", 11}};
};
TEST_F(MetricsHandlerTests, Call)
{
EXPECT_CALL(*adminVerifier_, isAdmin).WillOnce(testing::Return(true));
runSpawn([&](boost::asio::yield_context yield) {
auto response = metricsHandler_(request_, connectionMock_, nullptr, yield);
auto const httpResponse = std::move(response).intoHttpResponse();
EXPECT_EQ(httpResponse.result(), boost::beast::http::status::ok);
});
}
struct HealthCheckHandlerTests : SyncAsioContextTest, WebHandlersTest {
web::ng::Request request_{http::request<http::string_body>{http::verb::get, "/", 11}};
HealthCheckHandler healthCheckHandler_;
};
TEST_F(HealthCheckHandlerTests, Call)
{
runSpawn([&](boost::asio::yield_context yield) {
auto response = healthCheckHandler_(request_, connectionMock_, nullptr, yield);
auto const httpResponse = std::move(response).intoHttpResponse();
EXPECT_EQ(httpResponse.result(), boost::beast::http::status::ok);
});
}
struct RequestHandlerTest : SyncAsioContextTest, WebHandlersTest {
AdminVerificationStrategyStrictMockPtr adminVerifier_{
std::make_shared<testing::StrictMock<AdminVerificationStrategyMock>>()
};
struct RpcHandlerMock {
MOCK_METHOD(
web::ng::Response,
call,
(web::ng::Request const&,
web::ng::ConnectionMetadata const&,
web::SubscriptionContextPtr,
boost::asio::yield_context),
()
);
web::ng::Response
operator()(
web::ng::Request const& request,
web::ng::ConnectionMetadata const& connectionMetadata,
web::SubscriptionContextPtr subscriptionContext,
boost::asio::yield_context yield
)
{
return call(request, connectionMetadata, std::move(subscriptionContext), yield);
}
};
testing::StrictMock<RpcHandlerMock> rpcHandler_;
StrictMockConnection connectionMock_{ip_, boost::beast::flat_buffer{}, tagFactory_};
RequestHandler<RpcHandlerMock> requestHandler_{adminVerifier_, rpcHandler_, dosGuardMock_};
};
TEST_F(RequestHandlerTest, DosguardRateLimited_Http)
{
web::ng::Request const request{http::request<http::string_body>{http::verb::get, "/", 11}};
EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(false));
runSpawn([&](boost::asio::yield_context yield) {
auto response = requestHandler_(request, connectionMock_, nullptr, yield);
auto const httpResponse = std::move(response).intoHttpResponse();
EXPECT_EQ(httpResponse.result(), boost::beast::http::status::service_unavailable);
auto const body = boost::json::parse(httpResponse.body()).as_object();
EXPECT_EQ(body.at("error").as_string(), "slowDown");
EXPECT_EQ(body.at("error_code").as_int64(), 10);
EXPECT_EQ(body.at("status").as_string(), "error");
EXPECT_FALSE(body.contains("id"));
EXPECT_FALSE(body.contains("request"));
});
}
TEST_F(RequestHandlerTest, DosguardRateLimited_Ws)
{
auto const requestMessage = R"json({"some": "request", "id": "some id"})json";
web::ng::Request::HttpHeaders const headers{};
web::ng::Request const request{requestMessage, headers};
EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(false));
runSpawn([&](boost::asio::yield_context yield) {
auto const response = requestHandler_(request, connectionMock_, nullptr, yield);
auto const message = boost::json::parse(response.message()).as_object();
EXPECT_EQ(message.at("error").as_string(), "slowDown");
EXPECT_EQ(message.at("error_code").as_int64(), 10);
EXPECT_EQ(message.at("status").as_string(), "error");
EXPECT_EQ(message.at("id").as_string(), "some id");
EXPECT_EQ(message.at("request").as_string(), requestMessage);
});
}
TEST_F(RequestHandlerTest, DosguardRateLimited_Ws_ErrorParsing)
{
auto const requestMessage = R"json(some request "id": "some id")json";
web::ng::Request::HttpHeaders const headers{};
web::ng::Request const request{requestMessage, headers};
EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(false));
runSpawn([&](boost::asio::yield_context yield) {
auto const response = requestHandler_(request, connectionMock_, nullptr, yield);
auto const message = boost::json::parse(response.message()).as_object();
EXPECT_EQ(message.at("error").as_string(), "slowDown");
EXPECT_EQ(message.at("error_code").as_int64(), 10);
EXPECT_EQ(message.at("status").as_string(), "error");
EXPECT_FALSE(message.contains("id"));
EXPECT_EQ(message.at("request").as_string(), requestMessage);
});
}
TEST_F(RequestHandlerTest, RpcHandlerThrows)
{
web::ng::Request const request{http::request<http::string_body>{http::verb::get, "/", 11}};
EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(true));
EXPECT_CALL(*adminVerifier_, isAdmin).WillOnce(testing::Return(true));
EXPECT_CALL(rpcHandler_, call).WillOnce(testing::Throw(std::runtime_error{"some error"}));
runSpawn([&](boost::asio::yield_context yield) {
auto response = requestHandler_(request, connectionMock_, nullptr, yield);
auto const httpResponse = std::move(response).intoHttpResponse();
EXPECT_EQ(httpResponse.result(), boost::beast::http::status::internal_server_error);
auto const body = boost::json::parse(httpResponse.body()).as_object();
EXPECT_EQ(body.at("error").as_string(), "internal");
EXPECT_EQ(body.at("error_code").as_int64(), 73);
EXPECT_EQ(body.at("status").as_string(), "error");
});
}
TEST_F(RequestHandlerTest, NoErrors)
{
web::ng::Request const request{http::request<http::string_body>{http::verb::get, "/", 11}};
web::ng::Response const response{http::status::ok, "some response", request};
auto const httpResponse = web::ng::Response{response}.intoHttpResponse();
EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(true));
EXPECT_CALL(*adminVerifier_, isAdmin).WillOnce(testing::Return(true));
EXPECT_CALL(rpcHandler_, call).WillOnce(testing::Return(response));
EXPECT_CALL(dosGuardMock_, add(ip_, testing::_)).WillOnce(testing::Return(true));
runSpawn([&](boost::asio::yield_context yield) {
auto actualResponse = requestHandler_(request, connectionMock_, nullptr, yield);
auto const actualHttpResponse = std::move(actualResponse).intoHttpResponse();
EXPECT_EQ(actualHttpResponse.result(), httpResponse.result());
EXPECT_EQ(actualHttpResponse.body(), httpResponse.body());
EXPECT_EQ(actualHttpResponse.version(), 11);
});
}
TEST_F(RequestHandlerTest, ResponseDosGuardWarning_ResponseHasWarnings)
{
web::ng::Request const request{http::request<http::string_body>{http::verb::get, "/", 11}};
web::ng::Response const response{
http::status::ok, R"json({"some":"response", "warnings":["some warning"]})json", request
};
auto const httpResponse = web::ng::Response{response}.intoHttpResponse();
EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(true));
EXPECT_CALL(*adminVerifier_, isAdmin).WillOnce(testing::Return(true));
EXPECT_CALL(rpcHandler_, call).WillOnce(testing::Return(response));
EXPECT_CALL(dosGuardMock_, add(ip_, testing::_)).WillOnce(testing::Return(false));
runSpawn([&](boost::asio::yield_context yield) {
auto actualResponse = requestHandler_(request, connectionMock_, nullptr, yield);
auto const actualHttpResponse = std::move(actualResponse).intoHttpResponse();
EXPECT_EQ(actualHttpResponse.result(), httpResponse.result());
EXPECT_EQ(actualHttpResponse.version(), 11);
auto actualBody = boost::json::parse(actualHttpResponse.body()).as_object();
EXPECT_EQ(actualBody.at("some").as_string(), "response");
EXPECT_EQ(actualBody.at("warnings").as_array().size(), 2);
});
}
TEST_F(RequestHandlerTest, ResponseDosGuardWarning_ResponseDoesntHaveWarnings)
{
web::ng::Request const request{http::request<http::string_body>{http::verb::get, "/", 11}};
web::ng::Response const response{http::status::ok, R"json({"some":"response"})json", request};
auto const httpResponse = web::ng::Response{response}.intoHttpResponse();
EXPECT_CALL(dosGuardMock_, request(ip_)).WillOnce(testing::Return(true));
EXPECT_CALL(*adminVerifier_, isAdmin).WillOnce(testing::Return(true));
EXPECT_CALL(rpcHandler_, call).WillOnce(testing::Return(response));
EXPECT_CALL(dosGuardMock_, add(ip_, testing::_)).WillOnce(testing::Return(false));
runSpawn([&](boost::asio::yield_context yield) {
auto actualResponse = requestHandler_(request, connectionMock_, nullptr, yield);
auto const actualHttpResponse = std::move(actualResponse).intoHttpResponse();
EXPECT_EQ(actualHttpResponse.result(), httpResponse.result());
EXPECT_EQ(actualHttpResponse.version(), 11);
auto actualBody = boost::json::parse(actualHttpResponse.body()).as_object();
EXPECT_EQ(actualBody.at("some").as_string(), "response");
EXPECT_EQ(actualBody.at("warnings").as_array().size(), 1);
});
}

View File

@@ -19,7 +19,7 @@
#include "util/AsioContextTestFixture.hpp"
#include "util/config/Config.hpp"
#include "web/dosguard/DOSGuardInterface.hpp"
#include "web/dosguard/DOSGuardMock.hpp"
#include "web/dosguard/IntervalSweepHandler.hpp"
#include <boost/json/parse.hpp>
@@ -40,10 +40,7 @@ protected:
}
)JSON";
struct DosGuardMock : BaseDOSGuard {
MOCK_METHOD(void, clear, (), (noexcept, override));
};
testing::StrictMock<DosGuardMock> guardMock;
DOSGuardStrictMock guardMock;
util::Config cfg{boost::json::parse(JSONData)};
IntervalSweepHandler sweepHandler{cfg, ctx, guardMock};

View File

@@ -34,7 +34,10 @@
using namespace web::ng;
namespace http = boost::beast::http;
struct RequestTest : public ::testing::Test {};
struct RequestTest : public ::testing::Test {
static Request::HttpHeaders const headers_;
};
Request::HttpHeaders const RequestTest::headers_ = {};
struct RequestMethodTestBundle {
std::string testName;
@@ -65,7 +68,7 @@ INSTANTIATE_TEST_SUITE_P(
},
RequestMethodTestBundle{
.testName = "WebSocket",
.request = Request{"websocket message", Request::HttpHeaders{}},
.request = Request{"websocket message", RequestTest::headers_},
.expectedMethod = Request::Method::Websocket,
},
RequestMethodTestBundle{
@@ -101,7 +104,7 @@ INSTANTIATE_TEST_SUITE_P(
},
RequestIsHttpTestBundle{
.testName = "WebSocketRequest",
.request = Request{"websocket message", Request::HttpHeaders{}},
.request = Request{"websocket message", RequestTest::headers_},
.expectedIsHttp = false,
}
),
@@ -124,7 +127,7 @@ TEST_F(RequestAsHttpRequestTest, HttpRequest)
TEST_F(RequestAsHttpRequestTest, WebSocketRequest)
{
Request const request{"websocket message", Request::HttpHeaders{}};
Request const request{"websocket message", RequestTest::headers_};
auto const maybeHttpRequest = request.asHttpRequest();
EXPECT_FALSE(maybeHttpRequest.has_value());
}
@@ -142,7 +145,7 @@ TEST_F(RequestMessageTest, HttpRequest)
TEST_F(RequestMessageTest, WebSocketRequest)
{
std::string const message = "websocket message";
Request const request{message, Request::HttpHeaders{}};
Request const request{message, RequestTest::headers_};
EXPECT_EQ(request.message(), message);
}
@@ -171,7 +174,7 @@ INSTANTIATE_TEST_SUITE_P(
},
RequestTargetTestBundle{
.testName = "WebSocketRequest",
.request = Request{"websocket message", Request::HttpHeaders{}},
.request = Request{"websocket message", RequestTest::headers_},
.expectedTarget = std::nullopt,
}
),

View File

@@ -17,10 +17,14 @@
*/
//==============================================================================
#include "util/Taggable.hpp"
#include "util/build/Build.hpp"
#include "util/config/Config.hpp"
#include "web/ng/MockConnection.hpp"
#include "web/ng/Request.hpp"
#include "web/ng/Response.hpp"
#include <boost/beast/core/flat_buffer.hpp>
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/status.hpp>
@@ -29,6 +33,7 @@
#include <boost/json/object.hpp>
#include <boost/json/serialize.hpp>
#include <fmt/core.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <string>
@@ -41,21 +46,23 @@ struct ResponseDeathTest : testing::Test {};
TEST_F(ResponseDeathTest, intoHttpResponseWithoutHttpData)
{
Request const request{"some messsage", Request::HttpHeaders{}};
web::ng::Response response{boost::beast::http::status::ok, "message", request};
Request::HttpHeaders const headers{};
Request const request{"some message", headers};
Response response{boost::beast::http::status::ok, "message", request};
EXPECT_DEATH(std::move(response).intoHttpResponse(), "");
}
TEST_F(ResponseDeathTest, asConstBufferWithHttpData)
{
Request const request{http::request<http::string_body>{http::verb::get, "/", 11}};
web::ng::Response const response{boost::beast::http::status::ok, "message", request};
Response const response{boost::beast::http::status::ok, "message", request};
EXPECT_DEATH(response.asWsResponse(), "");
}
struct ResponseTest : testing::Test {
int const httpVersion_ = 11;
http::status const responseStatus_ = http::status::ok;
Request::HttpHeaders const headers_;
};
TEST_F(ResponseTest, intoHttpResponse)
@@ -63,7 +70,7 @@ TEST_F(ResponseTest, intoHttpResponse)
Request const request{http::request<http::string_body>{http::verb::post, "/", httpVersion_, "some message"}};
std::string const responseMessage = "response message";
web::ng::Response response{responseStatus_, responseMessage, request};
Response response{responseStatus_, responseMessage, request};
auto const httpResponse = std::move(response).intoHttpResponse();
EXPECT_EQ(httpResponse.result(), responseStatus_);
@@ -83,7 +90,7 @@ TEST_F(ResponseTest, intoHttpResponseJson)
Request const request{http::request<http::string_body>{http::verb::post, "/", httpVersion_, "some message"}};
boost::json::object const responseMessage{{"key", "value"}};
web::ng::Response response{responseStatus_, responseMessage, request};
Response response{responseStatus_, responseMessage, request};
auto const httpResponse = std::move(response).intoHttpResponse();
EXPECT_EQ(httpResponse.result(), responseStatus_);
@@ -100,9 +107,9 @@ TEST_F(ResponseTest, intoHttpResponseJson)
TEST_F(ResponseTest, asConstBuffer)
{
Request const request("some request", Request::HttpHeaders{});
Request const request("some request", headers_);
std::string const responseMessage = "response message";
web::ng::Response const response{responseStatus_, responseMessage, request};
Response const response{responseStatus_, responseMessage, request};
auto const buffer = response.asWsResponse();
EXPECT_EQ(buffer.size(), responseMessage.size());
@@ -113,9 +120,9 @@ TEST_F(ResponseTest, asConstBuffer)
TEST_F(ResponseTest, asConstBufferJson)
{
Request const request("some request", Request::HttpHeaders{});
Request const request("some request", headers_);
boost::json::object const responseMessage{{"key", "value"}};
web::ng::Response const response{responseStatus_, responseMessage, request};
Response const response{responseStatus_, responseMessage, request};
auto const buffer = response.asWsResponse();
EXPECT_EQ(buffer.size(), boost::json::serialize(responseMessage).size());
@@ -123,3 +130,88 @@ TEST_F(ResponseTest, asConstBufferJson)
std::string const messageFromBuffer{static_cast<char const*>(buffer.data()), buffer.size()};
EXPECT_EQ(messageFromBuffer, boost::json::serialize(responseMessage));
}
TEST_F(ResponseTest, createFromStringAndConnection)
{
util::TagDecoratorFactory tagDecoratorFactory{util::Config{}};
StrictMockConnection connection{"some ip", boost::beast::flat_buffer{}, tagDecoratorFactory};
std::string const responseMessage = "response message";
EXPECT_CALL(connection, wasUpgraded()).WillOnce(testing::Return(false));
Response response{responseStatus_, responseMessage, connection};
EXPECT_EQ(response.message(), responseMessage);
auto const httpResponse = std::move(response).intoHttpResponse();
EXPECT_EQ(httpResponse.result(), responseStatus_);
auto const it = httpResponse.find(http::field::content_type);
ASSERT_NE(it, httpResponse.end());
EXPECT_EQ(it->value(), "text/html");
}
TEST_F(ResponseTest, createFromJsonAndConnection)
{
util::TagDecoratorFactory tagDecoratorFactory{util::Config{}};
StrictMockConnection connection{"some ip", boost::beast::flat_buffer{}, tagDecoratorFactory};
boost::json::object const responseMessage{{"key", "value"}};
EXPECT_CALL(connection, wasUpgraded()).WillOnce(testing::Return(false));
Response response{responseStatus_, responseMessage, connection};
EXPECT_EQ(response.message(), boost::json::serialize(responseMessage));
auto const httpResponse = std::move(response).intoHttpResponse();
EXPECT_EQ(httpResponse.result(), responseStatus_);
auto const it = httpResponse.find(http::field::content_type);
ASSERT_NE(it, httpResponse.end());
EXPECT_EQ(it->value(), "application/json");
}
TEST_F(ResponseTest, setMessageString_HttpResponse)
{
Request const request{http::request<http::string_body>{http::verb::post, "/", httpVersion_, "some request"}};
Response response{boost::beast::http::status::ok, "message", request};
std::string const newMessage = "new message";
response.setMessage(newMessage);
EXPECT_EQ(response.message(), newMessage);
auto const httpResponse = std::move(response).intoHttpResponse();
auto it = httpResponse.find(http::field::content_type);
ASSERT_NE(it, httpResponse.end());
EXPECT_EQ(it->value(), "text/html");
}
TEST_F(ResponseTest, setMessageString_WsResponse)
{
Request const request{"some request", headers_};
Response response{boost::beast::http::status::ok, "message", request};
std::string const newMessage = "new message";
response.setMessage(newMessage);
EXPECT_EQ(response.message(), newMessage);
}
TEST_F(ResponseTest, setMessageJson_HttpResponse)
{
Request const request{http::request<http::string_body>{http::verb::post, "/", httpVersion_, "some request"}};
Response response{boost::beast::http::status::ok, "message", request};
boost::json::object const newMessage{{"key", "value"}};
response.setMessage(newMessage);
auto const httpResponse = std::move(response).intoHttpResponse();
auto it = httpResponse.find(http::field::content_type);
ASSERT_NE(it, httpResponse.end());
EXPECT_EQ(it->value(), "application/json");
}
TEST_F(ResponseTest, setMessageJson_WsResponse)
{
Request const request{"some request", headers_};
Response response{boost::beast::http::status::ok, "message", request};
boost::json::object const newMessage{{"key", "value"}};
response.setMessage(newMessage);
EXPECT_EQ(response.message(), boost::json::serialize(newMessage));
}

View File

@@ -36,6 +36,7 @@
#include <boost/asio/ip/address_v4.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/asio/steady_timer.hpp>
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/status.hpp>
#include <boost/beast/http/string_body.hpp>
@@ -68,7 +69,8 @@ struct MakeServerTest : NoLoggerFixture, testing::WithParamInterface<MakeServerT
TEST_P(MakeServerTest, Make)
{
util::Config const config{boost::json::parse(GetParam().configJson)};
auto const expectedServer = make_Server(config, ioContext_);
auto const expectedServer =
make_Server(config, [](auto&&) -> std::expected<void, Response> { return {}; }, [](auto&&) {}, ioContext_);
EXPECT_EQ(expectedServer.has_value(), GetParam().expectSuccess);
}
@@ -159,7 +161,9 @@ struct ServerTest : SyncAsioContextTest {
boost::json::object{{"server", boost::json::object{{"ip", "127.0.0.1"}, {"port", serverPort_}}}}
};
std::expected<Server, std::string> server_ = make_Server(config_, ctx);
Server::OnConnectCheck emptyOnConnectCheck_ = [](auto&&) -> std::expected<void, Response> { return {}; };
std::expected<Server, std::string> server_ = make_Server(config_, emptyOnConnectCheck_, [](auto&&) {}, ctx);
std::string requestMessage_ = "some request";
std::string const headerName_ = "Some-header";
@@ -181,8 +185,17 @@ TEST_F(ServerTest, BadEndpoint)
boost::asio::ip::tcp::endpoint const endpoint{boost::asio::ip::address_v4::from_string("1.2.3.4"), 0};
util::TagDecoratorFactory const tagDecoratorFactory{util::Config{boost::json::value{}}};
Server server{
ctx, endpoint, std::nullopt, ProcessingPolicy::Sequential, std::nullopt, tagDecoratorFactory, std::nullopt
ctx,
endpoint,
std::nullopt,
ProcessingPolicy::Sequential,
std::nullopt,
tagDecoratorFactory,
std::nullopt,
emptyOnConnectCheck_,
[](auto&&) {}
};
auto maybeError = server.run();
ASSERT_TRUE(maybeError.has_value());
EXPECT_THAT(*maybeError, testing::HasSubstr("Error creating TCP acceptor"));
@@ -224,6 +237,170 @@ TEST_F(ServerHttpTest, ClientDisconnects)
runContext();
}
TEST_F(ServerHttpTest, OnConnectCheck)
{
auto const serverPort = tests::util::generateFreePort();
boost::asio::ip::tcp::endpoint const endpoint{boost::asio::ip::address_v4::from_string("0.0.0.0"), serverPort};
util::TagDecoratorFactory const tagDecoratorFactory{util::Config{boost::json::value{}}};
testing::StrictMock<testing::MockFunction<std::expected<void, Response>(Connection const&)>> onConnectCheck;
Server server{
ctx,
endpoint,
std::nullopt,
ProcessingPolicy::Sequential,
std::nullopt,
tagDecoratorFactory,
std::nullopt,
onConnectCheck.AsStdFunction(),
[](auto&&) {}
};
HttpAsyncClient client{ctx};
boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) {
boost::asio::steady_timer timer{yield.get_executor()};
EXPECT_CALL(onConnectCheck, Call)
.WillOnce([&timer](Connection const& connection) -> std::expected<void, Response> {
EXPECT_EQ(connection.ip(), "127.0.0.1");
timer.cancel();
return {};
});
auto maybeError =
client.connect("127.0.0.1", std::to_string(serverPort), yield, std::chrono::milliseconds{100});
[&]() { ASSERT_FALSE(maybeError.has_value()) << maybeError->message(); }();
// Have to send a request here because the server does async_detect_ssl() which waits for some data to appear
client.send(
http::request<http::string_body>{http::verb::get, "/", 11, requestMessage_},
yield,
std::chrono::milliseconds{100}
);
// Wait for the onConnectCheck to be called
timer.expires_after(std::chrono::milliseconds{100});
boost::system::error_code error; // Unused
timer.async_wait(yield[error]);
client.gracefulShutdown();
ctx.stop();
});
server.run();
runContext();
}
TEST_F(ServerHttpTest, OnConnectCheckFailed)
{
auto const serverPort = tests::util::generateFreePort();
boost::asio::ip::tcp::endpoint const endpoint{boost::asio::ip::address_v4::from_string("0.0.0.0"), serverPort};
util::TagDecoratorFactory const tagDecoratorFactory{util::Config{boost::json::value{}}};
testing::StrictMock<testing::MockFunction<std::expected<void, Response>(Connection const&)>> onConnectCheck;
Server server{
ctx,
endpoint,
std::nullopt,
ProcessingPolicy::Sequential,
std::nullopt,
tagDecoratorFactory,
std::nullopt,
onConnectCheck.AsStdFunction(),
[](auto&&) {}
};
HttpAsyncClient client{ctx};
EXPECT_CALL(onConnectCheck, Call).WillOnce([](Connection const& connection) {
EXPECT_EQ(connection.ip(), "127.0.0.1");
return std::unexpected{
Response{http::status::too_many_requests, boost::json::object{{"error", "some error"}}, connection}
};
});
boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) {
auto maybeError =
client.connect("127.0.0.1", std::to_string(serverPort), yield, std::chrono::milliseconds{100});
[&]() { ASSERT_FALSE(maybeError.has_value()) << maybeError->message(); }();
// Have to send a request here because the server does async_detect_ssl() which waits for some data to appear
client.send(
http::request<http::string_body>{http::verb::get, "/", 11, requestMessage_},
yield,
std::chrono::milliseconds{100}
);
auto const response = client.receive(yield, std::chrono::milliseconds{100});
[&]() { ASSERT_TRUE(response.has_value()) << response.error().message(); }();
EXPECT_EQ(response->result(), http::status::too_many_requests);
EXPECT_EQ(response->body(), R"json({"error":"some error"})json");
EXPECT_EQ(response->version(), 11);
client.gracefulShutdown();
ctx.stop();
});
server.run();
runContext();
}
TEST_F(ServerHttpTest, OnDisconnectHook)
{
auto const serverPort = tests::util::generateFreePort();
boost::asio::ip::tcp::endpoint const endpoint{boost::asio::ip::address_v4::from_string("0.0.0.0"), serverPort};
util::TagDecoratorFactory const tagDecoratorFactory{util::Config{boost::json::value{}}};
testing::StrictMock<testing::MockFunction<void(Connection const&)>> OnDisconnectHookMock;
Server server{
ctx,
endpoint,
std::nullopt,
ProcessingPolicy::Sequential,
std::nullopt,
tagDecoratorFactory,
std::nullopt,
emptyOnConnectCheck_,
OnDisconnectHookMock.AsStdFunction()
};
HttpAsyncClient client{ctx};
boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) {
boost::asio::steady_timer timer{ctx.get_executor(), std::chrono::milliseconds{100}};
EXPECT_CALL(OnDisconnectHookMock, Call).WillOnce([&timer](auto&&) { timer.cancel(); });
auto maybeError =
client.connect("127.0.0.1", std::to_string(serverPort), yield, std::chrono::milliseconds{100});
[&]() { ASSERT_FALSE(maybeError.has_value()) << maybeError->message(); }();
client.send(
http::request<http::string_body>{http::verb::get, "/", 11, requestMessage_},
yield,
std::chrono::milliseconds{100}
);
client.gracefulShutdown();
// Wait for OnDisconnectHook is called
boost::system::error_code error;
timer.async_wait(yield[error]);
ctx.stop();
});
server.run();
runContext();
}
TEST_P(ServerHttpTest, RequestResponse)
{
HttpAsyncClient client{ctx};
@@ -300,7 +477,8 @@ TEST_F(ServerTest, WsRequestResponse)
{
WebSocketAsyncClient client{ctx};
Response const response{http::status::ok, "some response", Request{requestMessage_, Request::HttpHeaders{}}};
Request::HttpHeaders const headers{};
Response const response{http::status::ok, "some response", Request{requestMessage_, headers}};
boost::asio::spawn(ctx, [&](boost::asio::yield_context yield) {
auto maybeError =

View File

@@ -64,7 +64,14 @@ namespace websocket = boost::beast::websocket;
struct ConnectionHandlerTest : SyncAsioContextTest {
ConnectionHandlerTest(ProcessingPolicy policy, std::optional<size_t> maxParallelConnections)
: tagFactory_{util::Config{}}, connectionHandler_{policy, maxParallelConnections, tagFactory_, std::nullopt}
: tagFactory_{util::Config{}}
, connectionHandler_{
policy,
maxParallelConnections,
tagFactory_,
std::nullopt,
onDisconnectMock_.AsStdFunction()
}
{
}
@@ -93,6 +100,7 @@ struct ConnectionHandlerTest : SyncAsioContextTest {
return Request{std::forward<Args>(args)...};
}
testing::StrictMock<testing::MockFunction<void(Connection const&)>> onDisconnectMock_;
util::TagDecoratorFactory tagFactory_;
ConnectionHandler connectionHandler_;
@@ -101,6 +109,8 @@ struct ConnectionHandlerTest : SyncAsioContextTest {
std::make_unique<StrictMockHttpConnection>("1.2.3.4", beast::flat_buffer{}, tagDecoratorFactory_);
StrictMockWsConnectionPtr mockWsConnection_ =
std::make_unique<StrictMockWsConnection>("1.2.3.4", beast::flat_buffer{}, tagDecoratorFactory_);
Request::HttpHeaders headers_;
};
struct ConnectionHandlerSequentialProcessingTest : ConnectionHandlerTest {
@@ -113,6 +123,9 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, ReceiveError)
{
EXPECT_CALL(*mockHttpConnection_, wasUpgraded).WillOnce(Return(false));
EXPECT_CALL(*mockHttpConnection_, receive).WillOnce(Return(makeError(http::error::end_of_stream)));
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockHttpConnection_), yield);
@@ -124,6 +137,9 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, ReceiveError_CloseConnection)
EXPECT_CALL(*mockHttpConnection_, wasUpgraded).WillOnce(Return(false));
EXPECT_CALL(*mockHttpConnection_, receive).WillOnce(Return(makeError(boost::asio::error::timed_out)));
EXPECT_CALL(*mockHttpConnection_, close);
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockHttpConnection_), yield);
@@ -134,7 +150,7 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_NoHandler_Send)
{
EXPECT_CALL(*mockHttpConnection_, wasUpgraded).WillOnce(Return(false));
EXPECT_CALL(*mockHttpConnection_, receive)
.WillOnce(Return(makeRequest("some_request", Request::HttpHeaders{})))
.WillOnce(Return(makeRequest("some_request", headers_)))
.WillOnce(Return(makeError(websocket::error::closed)));
EXPECT_CALL(*mockHttpConnection_, send).WillOnce([](Response response, auto&&, auto&&) {
@@ -142,6 +158,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_NoHandler_Send)
return std::nullopt;
});
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockHttpConnection_), yield);
});
@@ -165,6 +185,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_BadTarget_Send)
return std::nullopt;
});
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockHttpConnection_), yield);
});
@@ -182,6 +206,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_BadMethod_Send)
return std::nullopt;
});
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockHttpConnection_), yield);
});
@@ -199,7 +227,7 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_Send)
EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true));
EXPECT_CALL(*mockWsConnection_, receive)
.WillOnce(Return(makeRequest(requestMessage, Request::HttpHeaders{})))
.WillOnce(Return(makeRequest(requestMessage, headers_)))
.WillOnce(Return(makeError(websocket::error::closed)));
EXPECT_CALL(wsHandlerMock, Call).WillOnce([&](Request const& request, auto&&, auto&&, auto&&) {
@@ -212,6 +240,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_Send)
return std::nullopt;
});
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockWsConnection_), yield);
});
@@ -228,7 +260,7 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, SendSubscriptionMessage)
EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true));
EXPECT_CALL(*mockWsConnection_, receive)
.WillOnce(Return(makeRequest("", Request::HttpHeaders{})))
.WillOnce(Return(makeRequest("", headers_)))
.WillOnce(Return(makeError(websocket::error::closed)));
EXPECT_CALL(wsHandlerMock, Call)
@@ -246,6 +278,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, SendSubscriptionMessage)
return std::nullopt;
});
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockWsConnection_), yield);
});
@@ -262,7 +298,7 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, SubscriptionContextIsDisconnec
EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true));
testing::Expectation const expectationReceiveCalled = EXPECT_CALL(*mockWsConnection_, receive)
.WillOnce(Return(makeRequest("", Request::HttpHeaders{})))
.WillOnce(Return(makeRequest("", headers_)))
.WillOnce(Return(makeError(websocket::error::closed)));
EXPECT_CALL(wsHandlerMock, Call)
@@ -276,6 +312,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, SubscriptionContextIsDisconnec
EXPECT_CALL(onDisconnectHook, Call).After(expectationReceiveCalled);
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockWsConnection_), yield);
});
@@ -314,6 +354,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, SubscriptionContextIsNullForHt
EXPECT_CALL(*mockHttpConnection_, close);
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockHttpConnection_), yield);
});
@@ -354,6 +398,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_Send_Loop)
EXPECT_CALL(*mockHttpConnection_, close);
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockHttpConnection_), yield);
});
@@ -385,6 +433,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Receive_Handle_SendError)
return makeError(http::error::end_of_stream).error();
});
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockHttpConnection_), yield);
});
@@ -408,7 +460,7 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Stop)
if (connectionClosed) {
return makeError(websocket::error::closed);
}
return makeRequest(requestMessage, Request::HttpHeaders{});
return makeRequest(requestMessage, headers_);
});
EXPECT_CALL(wsHandlerMock, Call).Times(3).WillRepeatedly([&](Request const& request, auto&&, auto&&, auto&&) {
@@ -429,6 +481,10 @@ TEST_F(ConnectionHandlerSequentialProcessingTest, Stop)
EXPECT_CALL(*mockWsConnection_, close).WillOnce([&connectionClosed]() { connectionClosed = true; });
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockWsConnection_), yield);
});
@@ -459,6 +515,10 @@ TEST_F(ConnectionHandlerParallelProcessingTest, ReceiveError)
EXPECT_CALL(*mockHttpConnection_, wasUpgraded).WillOnce(Return(false));
EXPECT_CALL(*mockHttpConnection_, receive).WillOnce(Return(makeError(http::error::end_of_stream)));
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockHttpConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockHttpConnection_), yield);
});
@@ -476,7 +536,7 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send)
EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true));
EXPECT_CALL(*mockWsConnection_, receive)
.WillOnce(Return(makeRequest(requestMessage, Request::HttpHeaders{})))
.WillOnce(Return(makeRequest(requestMessage, headers_)))
.WillOnce(Return(makeError(websocket::error::closed)));
EXPECT_CALL(wsHandlerMock, Call).WillOnce([&](Request const& request, auto&&, auto&&, auto&&) {
@@ -489,6 +549,10 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send)
return std::nullopt;
});
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockWsConnection_), yield);
});
@@ -504,7 +568,7 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop)
std::string const requestMessage = "some message";
std::string const responseMessage = "some response";
auto const returnRequest = [&](auto&&, auto&&) { return makeRequest(requestMessage, Request::HttpHeaders{}); };
auto const returnRequest = [&](auto&&, auto&&) { return makeRequest(requestMessage, headers_); };
EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true));
EXPECT_CALL(*mockWsConnection_, receive)
@@ -524,6 +588,10 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop)
return std::nullopt;
});
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockWsConnection_), yield);
});
@@ -539,7 +607,7 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop_TooMany
std::string const requestMessage = "some message";
std::string const responseMessage = "some response";
auto const returnRequest = [&](auto&&, auto&&) { return makeRequest(requestMessage, Request::HttpHeaders{}); };
auto const returnRequest = [&](auto&&, auto&&) { return makeRequest(requestMessage, headers_); };
testing::Sequence const sequence;
EXPECT_CALL(*mockWsConnection_, wasUpgraded).WillOnce(Return(true));
@@ -583,6 +651,10 @@ TEST_F(ConnectionHandlerParallelProcessingTest, Receive_Handle_Send_Loop_TooMany
.Times(2)
.WillRepeatedly(Return(std::nullopt));
EXPECT_CALL(onDisconnectMock_, Call).WillOnce([connectionPtr = mockWsConnection_.get()](Connection const& c) {
EXPECT_EQ(&c, connectionPtr);
});
runSpawn([this](boost::asio::yield_context yield) {
connectionHandler_.processConnection(std::move(mockWsConnection_), yield);
});

View File

@@ -49,7 +49,8 @@ struct ng_ErrorHandlingTests : NoLoggerFixture {
{
if (isHttp)
return Request{http::request<http::string_body>{http::verb::post, "/", 11, body.value_or("")}};
return Request{body.value_or(""), Request::HttpHeaders{}};
static Request::HttpHeaders const headers_;
return Request{body.value_or(""), headers_};
}
};

View File

@@ -51,7 +51,8 @@ struct web_WsConnectionTests : SyncAsioContextTest {
util::TagDecoratorFactory tagDecoratorFactory_{util::Config{boost::json::object{{"log_tag_style", "int"}}}};
TestHttpServer httpServer_{ctx, "localhost"};
WebSocketAsyncClient wsClient_{ctx};
Request request_{"some request", Request::HttpHeaders{}};
Request::HttpHeaders const headers_;
Request request_{"some request", headers_};
std::unique_ptr<PlainWsConnection>
acceptConnection(boost::asio::yield_context yield)