#ifndef XRPL_WEBSOCKET_AUTOSOCKET_AUTOSOCKET_H_INCLUDED #define XRPL_WEBSOCKET_AUTOSOCKET_AUTOSOCKET_H_INCLUDED #include #include #include #include #include #include // Socket wrapper that supports both SSL and non-SSL connections. // Generally, handle it as you would an SSL connection. // To force a non-SSL connection, just don't call async_handshake. // To force SSL only inbound, call setSSLOnly. class AutoSocket { public: using ssl_socket = boost::asio::ssl::stream; using endpoint_type = boost::asio::ip::tcp::socket::endpoint_type; using socket_ptr = std::unique_ptr; using plain_socket = ssl_socket::next_layer_type; using lowest_layer_type = ssl_socket::lowest_layer_type; using handshake_type = ssl_socket::handshake_type; using error_code = boost::system::error_code; using callback = std::function; public: AutoSocket( boost::asio::io_context& s, boost::asio::ssl::context& c, bool secureOnly, bool plainOnly) : mSecure(secureOnly) , mBuffer((plainOnly || secureOnly) ? 0 : 4) , j_{beast::Journal::getNullSink()} { mSocket = std::make_unique(s, c); } AutoSocket(boost::asio::io_context& s, boost::asio::ssl::context& c) : AutoSocket(s, c, false, false) { } bool isSecure() { return mSecure; } ssl_socket& SSLSocket() { return *mSocket; } plain_socket& PlainSocket() { return mSocket->next_layer(); } beast::IP::Endpoint local_endpoint() { return beast::IP::from_asio(lowest_layer().local_endpoint()); } beast::IP::Endpoint remote_endpoint() { return beast::IP::from_asio(lowest_layer().remote_endpoint()); } lowest_layer_type& lowest_layer() { return mSocket->lowest_layer(); } void swap(AutoSocket& s) noexcept { mBuffer.swap(s.mBuffer); mSocket.swap(s.mSocket); std::swap(mSecure, s.mSecure); } boost::system::error_code cancel(boost::system::error_code& ec) { return lowest_layer().cancel(ec); } void async_handshake(handshake_type type, callback cbFunc) { if ((type == ssl_socket::client) || (mSecure)) { // must be ssl mSecure = true; mSocket->async_handshake(type, cbFunc); } else if (mBuffer.empty()) { // must be plain mSecure = false; post( mSocket->get_executor(), boost::beast::bind_handler(cbFunc, error_code())); } else { // autodetect mSocket->next_layer().async_receive( boost::asio::buffer(mBuffer), boost::asio::socket_base::message_peek, std::bind( &AutoSocket::handle_autodetect, this, cbFunc, std::placeholders::_1, std::placeholders::_2)); } } template void async_shutdown(ShutdownHandler handler) { if (isSecure()) mSocket->async_shutdown(handler); else { error_code ec; try { lowest_layer().shutdown(plain_socket::shutdown_both); } catch (boost::system::system_error& e) { ec = e.code(); } post( mSocket->get_executor(), boost::beast::bind_handler(handler, ec)); } } template void async_read_some(Seq const& buffers, Handler handler) { if (isSecure()) mSocket->async_read_some(buffers, handler); else PlainSocket().async_read_some(buffers, handler); } template void async_read_until(Seq const& buffers, Condition condition, Handler handler) { if (isSecure()) boost::asio::async_read_until( *mSocket, buffers, condition, handler); else boost::asio::async_read_until( PlainSocket(), buffers, condition, handler); } template void async_read_until( boost::asio::basic_streambuf& buffers, std::string const& delim, Handler handler) { if (isSecure()) boost::asio::async_read_until(*mSocket, buffers, delim, handler); else boost::asio::async_read_until( PlainSocket(), buffers, delim, handler); } template void async_read_until( boost::asio::basic_streambuf& buffers, MatchCondition cond, Handler handler) { if (isSecure()) boost::asio::async_read_until(*mSocket, buffers, cond, handler); else boost::asio::async_read_until( PlainSocket(), buffers, cond, handler); } template void async_write(Buf const& buffers, Handler handler) { if (isSecure()) boost::asio::async_write(*mSocket, buffers, handler); else boost::asio::async_write(PlainSocket(), buffers, handler); } template void async_write( boost::asio::basic_streambuf& buffers, Handler handler) { if (isSecure()) boost::asio::async_write(*mSocket, buffers, handler); else boost::asio::async_write(PlainSocket(), buffers, handler); } template void async_read(Buf const& buffers, Condition cond, Handler handler) { if (isSecure()) boost::asio::async_read(*mSocket, buffers, cond, handler); else boost::asio::async_read(PlainSocket(), buffers, cond, handler); } template void async_read( boost::asio::basic_streambuf& buffers, Condition cond, Handler handler) { if (isSecure()) boost::asio::async_read(*mSocket, buffers, cond, handler); else boost::asio::async_read(PlainSocket(), buffers, cond, handler); } template void async_read(Buf const& buffers, Handler handler) { if (isSecure()) boost::asio::async_read(*mSocket, buffers, handler); else boost::asio::async_read(PlainSocket(), buffers, handler); } template void async_write_some(Seq const& buffers, Handler handler) { if (isSecure()) mSocket->async_write_some(buffers, handler); else PlainSocket().async_write_some(buffers, handler); } protected: void handle_autodetect( callback cbFunc, error_code const& ec, size_t bytesTransferred) { using namespace ripple; if (ec) { JLOG(j_.warn()) << "Handle autodetect error: " << ec; cbFunc(ec); } else if ( (mBuffer[0] < 127) && (mBuffer[0] > 31) && ((bytesTransferred < 2) || ((mBuffer[1] < 127) && (mBuffer[1] > 31))) && ((bytesTransferred < 3) || ((mBuffer[2] < 127) && (mBuffer[2] > 31))) && ((bytesTransferred < 4) || ((mBuffer[3] < 127) && (mBuffer[3] > 31)))) { // not ssl JLOG(j_.trace()) << "non-SSL"; mSecure = false; cbFunc(ec); } else { // ssl JLOG(j_.trace()) << "SSL"; mSecure = true; mSocket->async_handshake(ssl_socket::server, cbFunc); } } private: socket_ptr mSocket; bool mSecure; std::vector mBuffer; beast::Journal j_; }; #endif