#include "util/TestWsServer.hpp"

#include "util/requests/Types.hpp"

#include <boost/asio/buffer.hpp>
#include <boost/asio/io_context.hpp>
#include <boost/asio/ip/address.hpp>
#include <boost/asio/ip/tcp.hpp>
#include <boost/asio/socket_base.hpp>
#include <boost/asio/spawn.hpp>
#include <boost/beast/core/buffers_to_string.hpp>
#include <boost/beast/core/error.hpp>
#include <boost/beast/core/flat_buffer.hpp>
#include <boost/beast/core/role.hpp>
#include <boost/beast/core/tcp_stream.hpp>
#include <boost/beast/http/field.hpp>
#include <boost/beast/http/message.hpp>
#include <boost/beast/http/read.hpp>
#include <boost/beast/http/string_body.hpp>
#include <boost/beast/websocket/error.hpp>
#include <boost/beast/websocket/rfc6455.hpp>
#include <boost/beast/websocket/stream.hpp>
#include <boost/beast/websocket/stream_base.hpp>
#include <gtest/gtest.h>

#include <algorithm>
#include <expected>
#include <functional>
#include <iterator>
#include <optional>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

namespace asio = boost::asio;
namespace websocket = boost::beast::websocket;

// Wraps an already-handshaken server-side websocket stream together with the
// HTTP headers the caller associates with it.
TestWsConnection::TestWsConnection(
    websocket::stream<boost::beast::tcp_stream> wsStream,
    std::vector<util::requests::HttpHeader> headers
)
    : ws_(std::move(wsStream)), headers_(std::move(headers))
{
}

// Takes over the stream and the captured headers from `other`.
TestWsConnection::TestWsConnection(TestWsConnection&& other)
    : ws_(std::move(other.ws_)), headers_(std::move(other.headers_))
{
}

// Writes `message` as a single websocket message.
// Returns std::nullopt on success, otherwise the error's description, so the
// caller decides whether a failed write is a test failure.
std::optional<std::string>
TestWsConnection::send(std::string const& message, boost::asio::yield_context yield)
{
    boost::beast::error_code ec;
    ws_.async_write(asio::buffer(message), yield[ec]);
    return ec ? std::optional<std::string>{ec.message()} : std::nullopt;
}

// Sends a websocket ping carrying `data`; any error is a fatal test failure.
void
TestWsConnection::sendPing(boost::beast::websocket::ping_data const& data, boost::asio::yield_context yield)
{
    boost::beast::error_code ec;
    ws_.async_ping(data, yield[ec]);

    // ASSERT_* expands to a bare `return;`, so it has to live in a void-returning callable.
    auto const requireSuccess = [&]() { ASSERT_FALSE(ec) << ec.message(); };
    requireSuccess();
}

// Reads one complete websocket message and returns its payload.
// Returns std::nullopt when the read reports websocket::error::closed (the peer
// closed the connection); any other error is a fatal test failure.
std::optional<std::string>
TestWsConnection::receive(boost::asio::yield_context yield)
{
    boost::beast::flat_buffer incoming;
    boost::beast::error_code ec;
    ws_.async_read(incoming, yield[ec]);

    if (ec == websocket::error::closed)
        return std::nullopt;

    // ASSERT_* expands to a bare `return;`, so it has to live in a void-returning callable.
    auto const requireSuccess = [&]() { ASSERT_FALSE(ec) << ec.message(); };
    requireSuccess();

    return boost::beast::buffers_to_string(incoming.data());
}

// Performs the websocket closing handshake with close_code::normal.
// Returns std::nullopt on success, otherwise the error's description.
std::optional<std::string>
TestWsConnection::close(boost::asio::yield_context yield)
{
    boost::beast::error_code ec;
    ws_.async_close(websocket::close_code::normal, yield[ec]);
    if (!ec)
        return std::nullopt;
    return ec.message();
}

// Headers supplied at construction (TestWsServer::acceptConnection passes the
// headers of the client's upgrade request).
std::vector<util::requests::HttpHeader> const&
TestWsConnection::headers() const
{
    return headers_;
}

// Installs `callback` as the stream's control-frame callback (invoked by Beast
// for ping/pong/close frames received while reading).
void
TestWsConnection::setControlFrameCallback(
    std::function<void(boost::beast::websocket::frame_type, std::string_view)> callback
)
{
    ws_.control_callback(std::move(callback));
}
// Removes any control-frame callback previously installed on the stream.
void
TestWsConnection::resetControlFrameCallback()
{
    ws_.control_callback();
}

// Opens and binds the acceptor on `host`. Port 0 lets the operating system pick
// a free port; use port() to find out which one was chosen. Listening starts
// lazily in the accept* methods.
TestWsServer::TestWsServer(asio::io_context& context, std::string const& host) : acceptor_(context)
{
    asio::ip::tcp::endpoint const endpoint(asio::ip::make_address(host), 0);
    acceptor_.open(endpoint.protocol());
    acceptor_.set_option(asio::socket_base::reuse_address(true));
    acceptor_.bind(endpoint);
}

// The port the acceptor is actually bound to, as a decimal string.
std::string
TestWsServer::port() const
{
    auto const boundPort = acceptor_.local_endpoint().port();
    return std::to_string(boundPort);
}

namespace {

// Copies every field of `request` into HttpHeader values. Fields Beast does not
// recognise (http::field::unknown) are keyed by their raw name string so that
// custom headers keep their names.
std::vector<util::requests::HttpHeader>
collectHeaders(boost::beast::http::request<boost::beast::http::string_body> const& request)
{
    std::vector<util::requests::HttpHeader> collected;
    for (auto const& field : request) {
        if (field.name() == boost::beast::http::field::unknown) {
            collected.push_back(util::requests::HttpHeader{field.name_string(), field.value()});
        } else {
            collected.push_back(util::requests::HttpHeader{field.name(), field.value()});
        }
    }
    return collected;
}

}  // namespace

// Accepts one TCP client, reads its HTTP request and completes the websocket
// handshake. Each failure stage is reported as a RequestError rather than an
// assertion, so tests can check for expected failures.
std::expected<TestWsConnection, util::requests::RequestError>
TestWsServer::acceptConnection(asio::yield_context yield)
{
    acceptor_.listen(asio::socket_base::max_listen_connections);

    boost::beast::error_code ec;
    asio::ip::tcp::socket socket(acceptor_.get_executor());
    acceptor_.async_accept(socket, yield[ec]);
    if (ec)
        return std::unexpected{util::requests::RequestError{"Accept error", ec}};

    boost::beast::flat_buffer readBuffer;
    boost::beast::http::request<boost::beast::http::string_body> upgradeRequest;
    boost::beast::http::async_read(socket, readBuffer, upgradeRequest, yield[ec]);
    if (ec)
        return std::unexpected{util::requests::RequestError{"Read error", ec}};

    auto headers = collectHeaders(upgradeRequest);

    if (!boost::beast::websocket::is_upgrade(upgradeRequest))
        return std::unexpected{util::requests::RequestError{"Not a websocket request"}};

    websocket::stream<boost::beast::tcp_stream> wsStream(std::move(socket));
    wsStream.set_option(websocket::stream_base::timeout::suggested(boost::beast::role_type::server));
    wsStream.async_accept(upgradeRequest, yield[ec]);
    if (ec)
        return std::unexpected{util::requests::RequestError{"Handshake error", ec}};

    return TestWsConnection(std::move(wsStream), std::move(headers));
}

// Accepts a TCP client and immediately discards it: the returned socket is a
// temporary, destroyed (and therefore closed) at the end of the statement.
void
TestWsServer::acceptConnectionAndDropIt(asio::yield_context yield)
{
    acceptConnectionWithoutHandshake(yield);
}

// Accepts a raw TCP client without performing any HTTP/websocket handshake.
// An accept error is a fatal test failure.
boost::asio::ip::tcp::socket
TestWsServer::acceptConnectionWithoutHandshake(boost::asio::yield_context yield)
{
    acceptor_.listen(asio::socket_base::max_listen_connections);

    asio::ip::tcp::socket accepted(acceptor_.get_executor());
    boost::beast::error_code ec;
    acceptor_.async_accept(accepted, yield[ec]);

    // ASSERT_* expands to a bare `return;`, so it has to live in a void-returning callable.
    auto const requireAccepted = [&]() { ASSERT_FALSE(ec) << ec.message(); };
    requireAccepted();

    return accepted;
}