diff --git a/modules/ripple_asio/sockets/ripple_MultiSocket.cpp b/modules/ripple_asio/sockets/ripple_MultiSocket.cpp index 9fd2e0ce7..a1f44b71e 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocket.cpp +++ b/modules/ripple_asio/sockets/ripple_MultiSocket.cpp @@ -4,6 +4,8 @@ */ //============================================================================== +#define MULTISOCKET_TEST_PROXY 1 + MultiSocket* MultiSocket::New (boost::asio::io_service& io_service, int flags) { return new MultiSocketType (io_service, flags); @@ -155,25 +157,46 @@ public: //-------------------------------------------------------------------------- template - void run_async (ClientArg const& clientArg, ServerArg const& serverArg) + void runProxy (ClientArg const& clientArg, ServerArg const& serverArg) { PeerTest::run , - TestPeerLogicAsyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); + TestPeerLogicProxyClient, TestPeerLogicSyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); + + PeerTest::run , + TestPeerLogicProxyClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); } + //-------------------------------------------------------------------------- + template void run (ClientArg const& clientArg, ServerArg const& serverArg) { PeerTest::run , - TestPeerLogicSyncClient, TestPeerLogicSyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); + TestPeerLogicSyncClient, TestPeerLogicSyncServer> + (clientArg, serverArg, timeoutSeconds).report (*this); PeerTest::run , - TestPeerLogicAsyncClient, TestPeerLogicSyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); + TestPeerLogicAsyncClient, TestPeerLogicSyncServer> + (clientArg, serverArg, timeoutSeconds).report (*this); PeerTest::run , - TestPeerLogicSyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); + TestPeerLogicSyncClient, TestPeerLogicAsyncServer> + (clientArg, serverArg, timeoutSeconds).report (*this); - run_async (clientArg, serverArg); + PeerTest::run , + TestPeerLogicAsyncClient, TestPeerLogicAsyncServer> + (clientArg, serverArg, timeoutSeconds).report (*this); + } + + //-------------------------------------------------------------------------- + + template + void testProxyFlags (int extraClientFlags, int extraServerFlags) + { + check_precondition (! MultiSocket::Flag (extraClientFlags).any_set (MultiSocket::Flag::client_role | MultiSocket::Flag::server_role)); + + runProxy (MultiSocket::Flag::client_role | extraClientFlags, + MultiSocket::Flag::server_role | extraServerFlags); } //-------------------------------------------------------------------------- @@ -183,7 +206,8 @@ public: { check_precondition (! MultiSocket::Flag (extraClientFlags).any_set (MultiSocket::Flag::client_role | MultiSocket::Flag::server_role)); - run (MultiSocket::Flag::client_role | extraClientFlags, MultiSocket::Flag::server_role | extraServerFlags); + run (MultiSocket::Flag::client_role | extraClientFlags, + MultiSocket::Flag::server_role | extraServerFlags); } template @@ -199,6 +223,11 @@ public: testFlags (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl_required); testFlags (0, MultiSocket::Flag::ssl); testFlags (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl); + + testProxyFlags (MultiSocket::Flag::proxy, MultiSocket::Flag::proxy); + testProxyFlags (MultiSocket::Flag::proxy, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl); + testProxyFlags (MultiSocket::Flag::proxy | MultiSocket::Flag::ssl, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl); + testProxyFlags (MultiSocket::Flag::proxy | MultiSocket::Flag::ssl, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl_required); } void runTest () diff --git a/modules/ripple_asio/sockets/ripple_MultiSocketType.h b/modules/ripple_asio/sockets/ripple_MultiSocketType.h index e74acb36b..432344917 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocketType.h +++ b/modules/ripple_asio/sockets/ripple_MultiSocketType.h @@ -32,12 +32,6 @@ public: typedef typename next_layer_type::lowest_layer_type lowest_layer_type; typedef typename boost::asio::ssl::stream ssl_stream_type; - // These types are used internaly - // - - // The type of stream that the HandshakeDetectStream uses - typedef next_layer_type& ssl_detector_stream_type; - template explicit MultiSocketType (Arg& arg, int flags = 0) : m_flags (cleaned_flags (flags)) @@ -158,8 +152,19 @@ protected: bool requires_handshake () { + if (is_client ()) + { + if (m_flags.set (Flag::ssl)) + return true; + + return false; + } + + //bassert (is_server ()); + if (m_stream != nullptr) return stream ().requires_handshake (); + return true; } @@ -300,7 +305,6 @@ protected: using namespace boost; if (asio::buffer_size (buffers) > 0) { - bassert (asio::buffer_size (buffers) > 0); typedef PrefilledReadStream next_type; typedef asio::ssl::stream this_type; SocketWrapper * const socket = new SocketWrapper ( @@ -311,16 +315,44 @@ protected: return new_ssl_stream (); } + // Creates a stream that detects whether or not an + // SSL handshake is being attempted. + // Socket* new_detect_ssl_stream () { SSL3DetectCallback* callback = new SSL3DetectCallback (this); - typedef HandshakeDetectStreamType this_type; + typedef HandshakeDetectStreamType this_type; return new SocketWrapper (callback, m_next_layer); } - Socket* new_proxy_stream () + template + Socket* new_detect_ssl_stream (ConstBufferSequence const& buffers) { - return nullptr; + using namespace boost; + if (asio::buffer_size (buffers) > 0) + { + SSL3DetectCallback* callback = new SSL3DetectCallback (this); + typedef HandshakeDetectStreamType this_type; + SocketWrapper * const socket = + new SocketWrapper (callback, m_next_layer); + socket->this_layer ().fill (buffers); + return socket; + } + return new_detect_ssl_stream (); + } + + // Creates a stream that detects whether or not a PROXY + // handshake is being sent. + // + template + Socket* new_proxy_stream (ConstBufferSequence const& buffers) + { + // Doesn't make sense to already have data if we're expecting PROXY + bassert (boost::asio::buffer_size (buffers) == 0); + + PROXYDetectCallback* callback = new PROXYDetectCallback (this); + typedef HandshakeDetectStreamType this_type; + return new SocketWrapper (callback, m_next_layer); } //-------------------------------------------------------------------------- @@ -328,7 +360,7 @@ protected: // Returns a stream appropriate for the current settings. If the // return value is nullptr, it means that the stream cannot be determined // until handshake or async_handshake is called. At that point, it will - // be necessary to use new_handshake_stream + // be necessary to use new_handshake_stream. // Socket* new_stream () { @@ -338,9 +370,8 @@ protected: { if (m_flags.set (Flag::proxy)) { - // sending PROXY unimplemented for now - SocketBase::pure_virtual (); - socket = nullptr; + // Client sends PROXY in the plain + socket = new_plain_stream (); } else if (! m_flags.any_set ( Flag::ssl)) @@ -401,13 +432,18 @@ protected: // // Returns true if the stream was set // - bool set_handshake_stream (handshake_type type, error_code& ec) + bool set_handshake_stream (handshake_type type, + error_code& ec, ConstBuffers const& buffers = ConstBuffers()) { + // VFALCO can't really assert on this because what if + // there was a previous stream? + // // We better require a handshake if this is getting called - bassert (requires_handshake ()); + //bassert (requires_handshake ()); // If a stream already exists, some logic went wrong. - bassert (m_stream == nullptr); + // Can't assert on this either, for the case of going proxy->ssldetect + //bassert (m_stream == nullptr); ec = error_code (); @@ -423,8 +459,8 @@ protected: // If we constructed in a server or client role and // the handshake type doesn't match, return an error. // - if ( (type == client && ! m_flags.set (Flag::client_role)) || - (type == server && ! m_flags.set (Flag::server_role))) + if ( (type == client && ! is_client ()) || + (type == server && ! is_server ())) { handshake_error (ec); return false; @@ -432,37 +468,40 @@ protected: Socket* socket = nullptr; - if (m_flags.set (Flag::client_role)) + if (is_client ()) { - // should have filtered this out already - bassert (! m_flags.set (Flag::proxy)); + // can't assert on this because we dont control the send, + // only when they decide to perform the handshake. + //bassert (! m_flags.set (Flag::proxy)); if (m_flags.set (Flag::ssl)) { - socket = new_ssl_stream (); + socket = new_ssl_stream (buffers); } else { - // bad + socket = new_plain_stream (buffers); } } else { + bassert (is_server ()); + if (m_flags.set (Flag::proxy)) { - socket = new_proxy_stream (); + socket = new_proxy_stream (buffers); } else if (m_flags.set (Flag::ssl_required)) { - socket = new_ssl_stream (); + socket = new_ssl_stream (buffers); } else if (m_flags.set (Flag::ssl)) { - socket = new_detect_ssl_stream (); + socket = new_detect_ssl_stream (buffers); } else { - // bad + socket = new_plain_stream (buffers); } } @@ -475,13 +514,14 @@ protected: //-------------------------------------------------------------------------- - void on_detect_ssl (error_code& ec, ConstBuffers const& buffers, bool is_ssl) + void on_detect_ssl (HandshakeDetectLogicSSL3& logic, + error_code& ec, ConstBuffers const& buffers) { bassert (is_server ()); if (! ec) { - if (is_ssl) + if (logic.success ()) { #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE if (useBufferedHandshake) @@ -502,12 +542,12 @@ protected: } // origHandler cannot be a reference since it gets move-copied - void on_async_detect_ssl (error_code const& ec, - ConstBuffers const& buffers, ErrorCall origHandler, bool is_ssl) + void on_async_detect_ssl (HandshakeDetectLogicSSL3& logic, + error_code const& ec, ConstBuffers const& buffers, ErrorCall origHandler) { if (! ec) { - if (is_ssl) + if (logic.success ()) { m_stream = new_ssl_stream (buffers); // Is this a continuation? @@ -527,14 +567,14 @@ protected: #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE // origHandler cannot be a reference since it gets move-copied - void on_async_detect_ssl (error_code const& ec, - ConstBuffers const& buffers, TransferCall origHandler, bool is_ssl) + void on_async_detect_ssl (HandshakeDetectLogicSSL3& logic, + error_code const& ec, ConstBuffers const& buffers, TransferCall origHandler) { std::size_t bytes_transferred = 0; if (! ec) { - if (is_ssl) + if (logic.success ()) { // Is this a continuation? m_stream = new_ssl_stream (); @@ -556,9 +596,98 @@ protected: } #endif + //-------------------------------------------------------------------------- + + void setProxyInfo (HandshakeDetectLogicPROXY::ProxyInfo const proxyInfo) + { + // Do something with it + } + + void on_detect_proxy (HandshakeDetectLogicPROXY& logic, + error_code& ec, ConstBuffers const& buffers) + { + bassert (is_server ()); + + if (ec) + return; + + if (! logic.success ()) + { + handshake_error (ec); + return; + } + + // Remove the proxy flag + m_flags = m_flags.without (Flag::proxy); + + setProxyInfo (logic.getInfo ()); + + if (set_handshake_stream (Socket::server, ec, buffers)) + { + if (stream ().requires_handshake ()) + stream ().handshake (Socket::server, ec); + } + } + + // origHandler cannot be a reference since it gets move-copied + void on_async_detect_proxy (HandshakeDetectLogicPROXY& logic, + error_code ec, ConstBuffers const& buffers, ErrorCall origHandler) + { + if (! ec) + { + if (! logic.success ()) + handshake_error (ec); + } + + if (! ec) + { + // Remove the proxy flag + m_flags = m_flags.without (Flag::proxy); + + setProxyInfo (logic.getInfo ()); + + if (set_handshake_stream (Socket::server, ec, buffers)) + { + if (stream ().requires_handshake ()) + { + stream ().async_handshake (Socket::server, origHandler); + } + else + { + get_io_service ().post (CompletionCall (origHandler, ec)); + } + + return; + } + else + { + bassert (ec); + } + } + + if (ec) + get_io_service ().post (CompletionCall (origHandler, ec)); + } + +#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE + // origHandler cannot be a reference since it gets move-copied + void on_async_detect_proxy (HandshakeDetectLogicPROXY& logic, + error_code const& ec, ConstBuffers const& buffers, TransferCall origHandler) + { + std::size_t bytes_transferred = 0; + + if (! ec) + { + if (logic.success ()) + { + } + } + } +#endif + //-------------------------------------------------------------------------- // - // HandshakeDetectStream calback + // SSL3 Callback // class SSL3DetectCallback : public HandshakeDetectStream ::Callback @@ -574,20 +703,58 @@ protected: void on_detect (LogicType& logic, error_code& ec, ConstBuffers const& buffers) { - m_owner->on_detect_ssl (ec, buffers, logic.success ()); + m_owner->on_detect_ssl (logic, ec, buffers); } void on_async_detect (LogicType& logic, error_code const& ec, ConstBuffers const& buffers, ErrorCall const& origHandler) { - m_owner->on_async_detect_ssl (ec, buffers, origHandler, logic.success ()); + m_owner->on_async_detect_ssl (logic, ec, buffers, origHandler); } #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE void on_async_detect (LogicType& logic, error_code const& ec, ConstBuffers const& buffers, TransferCall const& origHandler) { - m_owner->on_async_detect_ssl (ec, buffers, origHandler, logic.success ()); + m_owner->on_async_detect_ssl (logic, ec, buffers, origHandler); + } + #endif + + private: + MultiSocketType * m_owner; + }; + + //-------------------------------------------------------------------------- + // + // PROXY callback + // + class PROXYDetectCallback + : public HandshakeDetectStream ::Callback + { + public: + typedef HandshakeDetectLogicPROXY LogicType; + + explicit PROXYDetectCallback (MultiSocketType * owner) + : m_owner (owner) + { + } + + void on_detect (LogicType& logic, error_code& ec, ConstBuffers const& buffers) + { + m_owner->on_detect_proxy (logic, ec, buffers); + } + + void on_async_detect (LogicType& logic, error_code const& ec, + ConstBuffers const& buffers, ErrorCall const& origHandler) + { + m_owner->on_async_detect_proxy (logic, ec, buffers, origHandler); + } + + #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE + void on_async_detect (LogicType& logic, error_code const& ec, + ConstBuffers const& buffers, TransferCall const& origHandler) + { + m_owner->on_async_detect_proxy (logic, ec, buffers, origHandler); } #endif