diff --git a/BeastConfig.h b/BeastConfig.h index 6807bb054d..54cefe5c0a 100644 --- a/BeastConfig.h +++ b/BeastConfig.h @@ -143,6 +143,18 @@ #define BEAST_DISABLE_CONTRACT_CHECKS 0 #endif +//------------------------------------------------------------------------------ + +/** Config: BEAST_COMPILER_CHECKS_SOCKET_OVERRIDES + + Setting this option makes Socket-derived classes generate compile errors if + they forget any of the virtual overrides As some Socket-derived classes + intentionally omit member functions that are not applicable, this macro + should only be enabled temporarily when writing your own Socket-derived class, + to make sure that the function signatures match as expected. +*/ +#define BEAST_COMPILER_CHECKS_SOCKET_OVERRIDES 0 + //------------------------------------------------------------------------------ // // Ripple compilation settings @@ -161,7 +173,7 @@ #define RIPPLE_USES_BEAST_SOCKETS 0 -#define BEAST_ASIO_HAS_BUFFEREDHANDSHAKE 0 +//#define BEAST_ASIO_HAS_BUFFEREDHANDSHAKE 0 //#define BEAST_ASIO_HAS_FUTURE_RETURNS 0 #endif diff --git a/Builds/VisualStudio2012/RippleD.vcxproj b/Builds/VisualStudio2012/RippleD.vcxproj index ecf7767f0c..5a526bb5a2 100644 --- a/Builds/VisualStudio2012/RippleD.vcxproj +++ b/Builds/VisualStudio2012/RippleD.vcxproj @@ -1665,7 +1665,6 @@ - diff --git a/Builds/VisualStudio2012/RippleD.vcxproj.filters b/Builds/VisualStudio2012/RippleD.vcxproj.filters index 5cf4f6ab53..95bf06a7da 100644 --- a/Builds/VisualStudio2012/RippleD.vcxproj.filters +++ b/Builds/VisualStudio2012/RippleD.vcxproj.filters @@ -1724,9 +1724,6 @@ - - %28Notes%29 - %28Notes%29 diff --git a/modules/ripple_asio/sockets/ripple_MultiSocket.cpp b/modules/ripple_asio/sockets/ripple_MultiSocket.cpp index cf5233cb19..49a129dfa4 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocket.cpp +++ b/modules/ripple_asio/sockets/ripple_MultiSocket.cpp @@ -4,8 +4,6 @@ */ //============================================================================== -#define MULTISOCKET_TEST_PROXY 1 - MultiSocket* MultiSocket::New (boost::asio::io_service& io_service, int flags) { return new MultiSocketType (io_service, flags); @@ -183,6 +181,7 @@ public: TestPeerLogicSyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); + PeerTest::run , TestPeerLogicAsyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); @@ -213,21 +212,41 @@ public: template void testProtocol () { - // These should pass. - run (0, 0); - run (MultiSocket::Flag::client_role, 0); + // Simple tests + run (0, + 0); + + run (MultiSocket::Flag::client_role, + 0); + run (0, MultiSocket::Flag::server_role); - run (MultiSocket::Flag::client_role, MultiSocket::Flag::server_role); - // These should pass - testFlags (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl_required); - testFlags (0, MultiSocket::Flag::ssl); - testFlags (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl); + run (MultiSocket::Flag::client_role, + MultiSocket::Flag::server_role); - testProxyFlags (MultiSocket::Flag::proxy, MultiSocket::Flag::proxy); - testProxyFlags (MultiSocket::Flag::proxy | MultiSocket::Flag::ssl, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl_required); - testProxyFlags (MultiSocket::Flag::proxy, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl); - testProxyFlags (MultiSocket::Flag::proxy | MultiSocket::Flag::ssl, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl); + testFlags (MultiSocket::Flag::ssl, + MultiSocket::Flag::ssl_required); + + // SSL-Detect tests + testFlags (0, + MultiSocket::Flag::ssl); + + testFlags (MultiSocket::Flag::ssl, + MultiSocket::Flag::ssl); + + // PROXY Handshake tests + testProxyFlags (MultiSocket::Flag::proxy, + MultiSocket::Flag::proxy); + + testProxyFlags (MultiSocket::Flag::proxy | MultiSocket::Flag::ssl, + MultiSocket::Flag::proxy | MultiSocket::Flag::ssl_required); + + // PROXY + SSL-Detect tests + testProxyFlags (MultiSocket::Flag::proxy, + MultiSocket::Flag::proxy | MultiSocket::Flag::ssl); + + testProxyFlags (MultiSocket::Flag::proxy | MultiSocket::Flag::ssl, + MultiSocket::Flag::proxy | MultiSocket::Flag::ssl); } void runTest () diff --git a/modules/ripple_asio/sockets/ripple_MultiSocketType.h b/modules/ripple_asio/sockets/ripple_MultiSocketType.h index 16e4f43e57..58ccca8d01 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocketType.h +++ b/modules/ripple_asio/sockets/ripple_MultiSocketType.h @@ -14,16 +14,6 @@ class MultiSocketType : public MultiSocket { protected: -#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE - // For some reason, the ssl::stream buffered handshaking - // routines don't seem to work right. Setting this bool - // to false will bypass the ssl::stream buffered handshaking - // API and just resort to our PrefilledReadStream which works - // great and we need it anyway to handle the non-buffered case. - // - static bool const useBufferedHandshake = false; -#endif - typedef boost::system::error_code error_code; public: @@ -34,17 +24,33 @@ public: template explicit MultiSocketType (Arg& arg, int flags = 0) - : m_flags (cleaned_flags (flags)) - , m_needs_handshake (false) - , m_needs_shutdown (false) - , m_needs_shutdown_stream (false) + : m_flags (flags) + , m_state (stateNone) + , m_stream (nullptr) + , m_needsShutdown (false) , m_next_layer (arg) - , m_stream (new_stream ()) , m_strand (get_io_service ()) { + // See if our flags allow us to go directly + // into the ready state with an active stream. + initState (); } protected: + // Tells us what to do next + // + enum State + { + stateNone, // Uninitialized, unloved. + stateHandshake, // We need a call to handshake() to proceed + stateExpectPROXY, // We expect to see a proxy handshake + stateDetectSSL, // We should detect SSL + stateHandshakeFinal, // Final call to underlying stream handshake() + stateReady // Stream is set and ready to go + }; + + //-------------------------------------------------------------------------- + /** The current stream we are passing everything through. This object gets dynamically created and replaced with other objects as we process the various flags for handshaking. @@ -136,20 +142,28 @@ protected: return stream ().write_some (buffers, ec); } - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(HandlerCall, void (error_code, std::size_t)) - async_read_some (MutableBuffers const& buffers, HandlerCall const& handler) + void async_read_some (MutableBuffers const& buffers, SharedHandlerPtr handler) { - return stream ().async_read_some (buffers, m_strand.wrap ( +#if 0 + stream ().async_read_some (buffeers, + BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler)); +#else + stream ().async_read_some (buffers, m_strand.wrap ( boost::bind (handler, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred))); +#endif } - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(HandlerCall, void (error_code, std::size_t)) - async_write_some (ConstBuffers const& buffers, HandlerCall const& handler) + void async_write_some (ConstBuffers const& buffers, SharedHandlerPtr handler) { - return stream ().async_write_some (buffers, m_strand.wrap ( +#if 0 + stream ().async_write_some (buffers, + BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler)); +#else + stream ().async_write_some (buffers, m_strand.wrap ( boost::bind ( handler, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred))); +#endif } //-------------------------------------------------------------------------- @@ -157,93 +171,99 @@ protected: // ssl::stream // - /** Determine if the caller needs to call a handshaking function. */ + /** Determine if the caller needs to call a handshaking function. + This is also used to determine if the handshaking shutdown() + has to be called. + */ bool needs_handshake () { - return m_needs_handshake; + return m_state == stateHandshake || + m_state == stateHandshakeFinal || + m_needsShutdown; } + //-------------------------------------------------------------------------- + error_code handshake (Socket::handshake_type type, error_code& ec) { - if (init_handshake (type, ec)) - return stream ().handshake (type, ec); - return ec; + return handshake (type, ConstBuffers (), ec); } - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(HandlerCall, void (error_code)) - async_handshake (handshake_type type, HandlerCall const& handler) - { - bassert (m_context.isNull ()); - - error_code ec; - if (init_handshake (type, ec)) - { - // The original handler has already been wrapped for us - // so set the context and indicate the beginning of the - // composed operation. - // - //handler.beginComposed (); - m_context = handler.getContext (); - return stream ().async_handshake (type, handler); - } - - return get_io_service ().post (HandlerCall (HandlerCall::Post (), - handler, ec)); - } - -#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE + // We always offer the buffered handshake version even pre-boost 1.54 since + // we need the ability to re-use data for multiple handshake stages anyway. + // error_code handshake (handshake_type type, ConstBuffers const& buffers, error_code& ec) { - if (init_handshake (type, ec)) - return stream ().handshake (type, buffers, ec); - return ec; + return do_handshake (type, buffers, ec); } - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(HandlerCall, void (error_code, std::size_t)) - async_handshake (handshake_type type, ConstBuffers const& buffers, - HandlerCall const& handler) + //-------------------------------------------------------------------------- + + void async_handshake (handshake_type type, SharedHandlerPtr handler) { - error_code ec; - if (init_handshake (type, ec)) - return stream ().async_handshake (type, buffers, handler); - return get_io_service ().post (HandlerCall ( - HandlerCall::Post (), handler, ec, 0)); + do_async_handshake (type, ConstBuffers (), handler); } -#endif + + void async_handshake (handshake_type type, + ConstBuffers const& buffers, SharedHandlerPtr handler) + { + do_async_handshake (type, buffers, handler); + } + + void do_async_handshake (handshake_type type, + ConstBuffers const& buffers, SharedHandlerPtr const& handler) + { + return get_io_service ().dispatch (SharedHandlerPtr ( + new AsyncOp (this, type, buffers, handler))); + } + + //-------------------------------------------------------------------------- error_code shutdown (error_code& ec) { - if (m_needs_shutdown) + if (m_needsShutdown) { - if (m_needs_shutdown_stream) - return stream ().shutdown (ec); + // Only do the shutdown if the stream really needs it. + if (stream ().needs_handshake ()) + return ec = stream ().shutdown (ec); + return ec = error_code (); } - return handshake_error (ec); + + return ec = handshake_error (); } - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(HandlerCall, void (error_code)) - async_shutdown (HandlerCall const& handler) + void async_shutdown (SharedHandlerPtr handler) { error_code ec; - if (m_needs_shutdown) + + if (m_needsShutdown) { - if (m_needs_shutdown_stream) - return stream ().async_shutdown (handler); + if (stream ().needs_handshake ()) + { + stream ().async_shutdown ( + BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler)); + + return; + } } else { + // Our interface didn't require a shutdown but someone called + // it anyone so generate an error code. + bassertfalse; ec = handshake_error (); } - return get_io_service ().post (HandlerCall (HandlerCall::Post (), - handler, ec)); + get_io_service ().wrap ( + BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler)) + (ec); } //-------------------------------------------------------------------------- // - // Handshake implementation details + // Utilities // //-------------------------------------------------------------------------- @@ -274,6 +294,212 @@ protected: return flags; } + bool is_server () const noexcept + { + return m_flags.set (Flag::server_role); + } + + bool is_client () const noexcept + { + return m_flags.set (Flag::client_role); + } + + bool is_peer () const noexcept + { + bassert (is_client () || is_server () || (m_flags == Flag (Flag::peer))); + return m_flags == Flag (Flag::peer); + } + + // Bottleneck to indicate a failed handshake. + // + error_code handshake_error () + { + // VFALCO TODO maybe use a ripple error category? + // set this to something custom that we can recognize later? + // + m_needsShutdown = false; + return boost::asio::error::invalid_argument; + } + + error_code handshake_error (error_code& ec) + { + return ec = handshake_error (); + } + + //-------------------------------------------------------------------------- + // + // State Machine + // + //-------------------------------------------------------------------------- + + // Initialize the current state based on the flags. This is + // called from the constructor. It is possible that a state + // cannot be determined until the handshake type is known, + // in which case we will leave the state at stateNone and the + // current stream set to nullptr + // + void initState () + { + // Clean our flags up + m_flags = cleaned_flags (m_flags); + + if (is_client ()) + { + if (m_flags.set (Flag::proxy)) + { + if (m_flags.set (Flag::ssl)) + m_state = stateHandshake; + else + m_state = stateReady; + + // Client sends PROXY in the plain so make + // sure they have an underlying stream right away. + m_stream = new_plain_stream (); + } + else if (m_flags.set (Flag::ssl)) + { + m_state = stateHandshakeFinal; + m_stream = nullptr; + } + else + { + m_state = stateReady; + m_stream = new_plain_stream (); + } + } + else if (is_server ()) + { + if (m_flags.set (Flag::proxy)) + { + // We expect a PROXY handshake. + // Create the plain stream at handshake time. + m_state = stateHandshake; + m_stream = nullptr; + } + else if (m_flags.set (Flag::ssl_required)) + { + // We require an SSL handshake. + // Create the stream at handshake time. + m_state = stateHandshakeFinal; + m_stream = nullptr; + } + else if (m_flags.set (Flag::ssl)) + { + // We will use the SSL detector at handshake + // time decide which type of stream to create. + m_state = stateHandshake; + m_stream = nullptr; + } + else + { + // No handshaking required. + m_state = stateReady; + m_stream = new_plain_stream (); + } + } + else + { + m_state = stateReady; + m_stream = new_plain_stream (); + } + + // We only set this to true in stateHandshake and + // after the handshake completes without an error. + // + m_needsShutdown = false; + } + + //-------------------------------------------------------------------------- + + // Used for the non-buffered handshake functions. + // + error_code initHandshake (handshake_type type) + { + return initHandshake (type, ConstBuffers ()); + } + + // Updates the state based on the now-known handshake type. The + // buffers parameter contains bytes that have already been received. + // This can come from the results of SSL detection, or from the buffered + // handshake API calls added in Boost v1.54.0. + // + error_code initHandshake (handshake_type type, ConstBuffers const& buffers) + { + switch (m_state) + { + case stateExpectPROXY: + case stateDetectSSL: + m_state = stateHandshake; + case stateHandshake: + case stateHandshakeFinal: + break; + case stateNone: + case stateReady: + default: + // Didn't need handshake, but someone called us anyway? + fatal_error ("invalid state"); + return handshake_error (); + } + + // Peer roles cannot handshake. + if (m_flags == Flag (Flag::peer)) + return handshake_error (); + + // Handshake type must match the role flags + if ( (type == client && ! is_client ()) || + (type == server && ! is_server ())) + return handshake_error (); + + if (is_client ()) + { + // if PROXY flag is set, then it should have already + // been sent in the clear before calling handshake() + // so strip the flag away. + // + m_flags = m_flags.without (Flag::proxy); + + // Someone forgot to call needs_handshake + if (! m_flags.set (Flag::ssl)) + return handshake_error (); + + m_state = stateHandshakeFinal; + m_stream = new_ssl_stream (buffers); + } + else + { + bassert (is_server ()); + + if (m_flags.set (Flag::proxy)) + { + // We will expect and consume a PROXY handshake, + // then come back here with the flag cleared. + m_state = stateExpectPROXY; + m_stream = new_plain_stream (); + } + else if (m_flags.set (Flag::ssl_required)) + { + // WE will perform a required final SSL handshake. + m_state = stateHandshakeFinal; + m_stream = new_ssl_stream (buffers); + } + else if (m_flags.set (Flag::ssl)) + { + // We will use the SSL detector to update + // our flags and come back through here. + m_state = stateDetectSSL; + m_stream = nullptr; + } + else + { + // Done with auto-detect + m_state = stateReady; + m_stream = new_plain_stream (buffers); + } + } + + return error_code (); + } + //-------------------------------------------------------------------------- // Create a plain stream that just wraps the next layer. @@ -324,240 +550,6 @@ protected: return new_ssl_stream (); } - // Creates a stream that detects whether or not an - // SSL handshake is being attempted. - // - Socket* new_detect_ssl_stream () - { - SSL3DetectCallback* callback = new SSL3DetectCallback (this); - typedef HandshakeDetectStreamType this_type; - return new SocketWrapper (callback, m_next_layer); - } - - // Creates an ssl detect stream and front-loads it with data. - // - template - Socket* new_detect_ssl_stream (ConstBufferSequence const& buffers) - { - if (boost::asio::buffer_size (buffers) > 0) - { - SSL3DetectCallback* callback = new SSL3DetectCallback (this); - typedef HandshakeDetectStreamType this_type; - SocketWrapper * const socket = - new SocketWrapper (callback, m_next_layer); - socket->this_layer ().fill (buffers); - return socket; - } - return new_detect_ssl_stream (); - } - - // Creates a stream that detects whether or not a PROXY - // handshake is being sent. - // - template - Socket* new_proxy_stream (ConstBufferSequence const& buffers) - { - // Doesn't make sense to already have data if we're expecting PROXY - check_precondition (boost::asio::buffer_size (buffers) == 0); - - PROXYDetectCallback* callback = new PROXYDetectCallback (this); - typedef HandshakeDetectStreamType this_type; - return new SocketWrapper (callback, m_next_layer); - } - - //-------------------------------------------------------------------------- - - // Called at construction time, this returns a stream appropriate for - // the current settings. If the return value is nullptr, it means that - // the stream cannot be determined until handshake or async_handshake - // is called. In that case, init_handshake will set m_stream. - // - // This also sets the m_needs_handshake and m_needs_shutdown flags - // - Socket* new_stream () - { - Socket* socket = nullptr; - - if (is_client ()) - { - if (m_flags.set (Flag::proxy)) - { - // Client sends PROXY in the plain so make - // sure they have an underlying stream right away. - socket = new_plain_stream (); - m_needs_handshake = m_flags.set (Flag::ssl); - } - else if (m_flags.set (Flag::ssl)) - { - socket = nullptr; - m_needs_handshake = true; - } - else - { - socket = new_plain_stream (); - m_needs_handshake = false; - } - } - else if (is_server ()) - { - if (m_flags.set (Flag::proxy)) - { - // PROXY comes first in the plain - socket = nullptr; - m_needs_handshake = true; - } - else if (m_flags.set (Flag::ssl_required)) - { - socket = nullptr; - m_needs_handshake = true; - } - else if (m_flags.set (Flag::ssl)) - { - socket = nullptr; - m_needs_handshake = true; - } - else - { - m_needs_handshake = false; - socket = new_plain_stream (); - } - } - else - { - // peer, plain stream, no handshake - m_needs_handshake = false; - socket = new_plain_stream (); - } - - // We only set this to true when the handshake succeeds. - m_needs_shutdown = false; - - return socket; - } - - //-------------------------------------------------------------------------- - - // Bottleneck to indicate a failed handshake. - // - error_code handshake_error () - { - // VFALCO TODO maybe use a ripple error category? - // set this to something custom that we can recognize later? - // - m_needs_shutdown = false; - return boost::asio::error::invalid_argument; - } - - error_code handshake_error (error_code& ec) - { - return ec = handshake_error (); - } - - // Sanity checks and cleans up our flags based on the desired handshake - // type and correctly sets the current stream based on the settings. - // If the requested handshake type is incompatible with the socket flags - // set on construction, an error is returned. - // - // This also clears the m_needs_handshake flag on an error. - // - // Returns true if the stream was set. - // - bool init_handshake (handshake_type type, - error_code& ec, ConstBuffers const& buffers = ConstBuffers()) - { - bassert (m_needs_handshake); - - // Peer roles cannot handshake. - if (m_flags == Flag (Flag::peer)) - { - handshake_error (ec); - return false; - } - - // Handshake type must match the role flags - if ( (type == client && ! is_client ()) || - (type == server && ! is_server ())) - { - handshake_error (ec); - return false; - } - - ec = error_code (); - - if (is_client ()) - { - // PROXY sent before handshake, in plain - m_flags = m_flags.without (Flag::proxy); - - if (! m_flags.set (Flag::ssl)) - { - handshake_error (ec); - return false; - } - - m_stream = new_ssl_stream (buffers); - m_needs_shutdown_stream = true; - } - else - { - bassert (is_server ()); - - if (m_flags.set (Flag::proxy)) - { - m_stream = new_proxy_stream (ConstBuffers ()); - } - else if (m_flags.set (Flag::ssl_required)) - { - m_stream = new_ssl_stream (buffers); - m_needs_shutdown_stream = true; - } - else if (m_flags.set (Flag::ssl)) - { - m_stream = new_detect_ssl_stream (buffers); - } - else - { - // This happens after a proxy handshake - m_stream = new_plain_stream (buffers); - } - } - - m_needs_shutdown = true; - - return true; - } - - //-------------------------------------------------------------------------- - // - // PROXY callback - // - class PROXYDetectCallback - : public HandshakeDetectStream ::Callback - { - public: - typedef HandshakeDetectLogicPROXY LogicType; - - explicit PROXYDetectCallback (MultiSocketType * owner) - : m_owner (owner) - { - } - - void on_detect (LogicType& logic, - error_code& ec, ConstBuffers const& buffers) - { - m_owner->on_detect_proxy (logic, ec, buffers); - } - - void on_async_detect (LogicType& logic, error_code const& ec, - ConstBuffers const& buffers, HandlerCall const& origHandler) - { - m_owner->on_async_detect_proxy (logic, ec, buffers, origHandler); - } - - private: - MultiSocketType * m_owner; - }; - //-------------------------------------------------------------------------- void setProxyInfo (HandshakeDetectLogicPROXY::ProxyInfo const proxyInfo) @@ -565,172 +557,318 @@ protected: // Do something with it } - void on_detect_proxy (HandshakeDetectLogicPROXY& logic, - error_code& ec, ConstBuffers const& buffers) + //-------------------------------------------------------------------------- + // + // Synchronous handshake operation + // + + error_code do_handshake (handshake_type type, + ConstBuffers const& buffers, error_code& ec) { - bassert (is_server ()); - + ec = initHandshake (type); if (ec) - return; + return ec; - if (! logic.success ()) + // How can we be ready if a handshake is needed? + bassert (m_state != stateReady); + + // Prepare our rolling detect buffer with any input + boost::asio::streambuf buffer; + buffer.commit (boost::asio::buffer_copy ( + buffer.prepare (buffers.size ()), buffers)); + + // Run a loop of processing and detecting handshakes + // layer after layer until we arrive at the ready state + // with a final stream. + // + do { - handshake_error (ec); - return; + switch (m_state) + { + case stateHandshakeFinal: + { + // A 'real' final handshake on the stream is needed. + m_state = stateReady; + stream ().handshake (type, ec); + } + break; + + case stateExpectPROXY: + { + typedef HandshakeDetectorType < + next_layer_type, HandshakeDetectLogicPROXY> op_type; + + op_type op; + + ec = op.detect (m_next_layer, buffer); + + if (! ec) + { + bassert (op.getLogic ().finished ()); + + if (op.getLogic ().success ()) + { + setProxyInfo (op.getLogic ().getInfo ()); + + // Strip off the PROXY flag. + m_flags = m_flags.without (Flag::proxy); + + // Update handshake state with the leftover bytes. + ec = initHandshake (type, buffer.data ()); + + // buffer input sequence intentionally untouched + } + else + { + // Didn't get the PROXY handshake we needed + ec = handshake_error (); + } + } + } + break; + + case stateDetectSSL: + { + typedef HandshakeDetectorType < + next_layer_type, HandshakeDetectLogicSSL3> op_type; + + op_type op; + + ec = op.detect (m_next_layer, buffer); + + if (! ec) + { + bassert (op.getLogic ().finished ()); + + // Was it SSL? + if (op.getLogic ().success ()) + { + // Convert the ssl flag to ssl_required + m_flags = m_flags.with (Flag::ssl_required).without (Flag::ssl); + } + else + { + // Not SSL, strip the ssl flag + m_flags = m_flags.without (Flag::ssl); + } + + // Update handshake state with the leftover bytes. + ec = initHandshake (type, buffer.data ()); + + // buffer input sequence intentionally untouched + } + } + break; + + case stateNone: + case stateReady: + case stateHandshake: + default: + fatal_error ("invalid state"); + ec = handshake_error (); + break; + } + + if (ec) + break; + } + while (m_state != stateReady); + + if (! ec) + { + // We should be in the ready state now. + bassert (m_state == stateReady); + + // Always need shutdown if handshake successful. + m_needsShutdown = true; } - // Remove the proxy flag - m_flags = m_flags.without (Flag::proxy); - - setProxyInfo (logic.getInfo ()); - - if (init_handshake (Socket::server, ec, buffers)) - { - if (stream ().needs_handshake ()) - stream ().handshake (Socket::server, ec); - } + return ec; } - // origHandler cannot be a reference since it gets move-copied - void on_async_detect_proxy (HandshakeDetectLogicPROXY& logic, - error_code ec, ConstBuffers const& buffers, HandlerCall origHandler) + //-------------------------------------------------------------------------- + // + // Composed asynchronous handshake operator + // + + struct AsyncOp : ComposedAsyncOperation { - if (! ec) + typedef SharedHandlerAllocator Allocator; + + AsyncOp (MultiSocketType * socket, + handshake_type type, ConstBuffers const& buffers, + SharedHandlerPtr const& handler) + : ComposedAsyncOperation (sizeof (*this), handler) + , m_handler (handler) + , m_socket (socket) + , m_type (type) + , m_buffer (std::numeric_limits ::max(), handler) + , m_running (false) { - // Failed to receive PROXY handshake - if (! logic.success ()) - handshake_error (ec); + // Prepare our rolling detect buffer with any input + // + // NOTE We have to do this in reverse order from the synchronous + // version because the callers buffers won't be in scope inside + // our call. + // + m_buffer.commit (boost::asio::buffer_copy ( + m_buffer.prepare (buffers.size ()), buffers)); } - if (! ec) + // Set breakpoint to prove it gets destroyed + ~AsyncOp () { - // Remove the proxy flag from our flag set - m_flags = m_flags.without (Flag::proxy); + } - setProxyInfo (logic.getInfo ()); + // This is the entry point into the composed operation + // + void operator() () + { + m_running = true; - if (init_handshake (Socket::server, ec, buffers)) + error_code ec (m_socket->initHandshake (m_type)); + + if (! ec) { - if (stream ().needs_handshake ()) + if (m_socket->m_state != stateReady) { - stream ().async_handshake (Socket::server, origHandler); + (*this) (ec); return; } + + // Always need shutdown if handshake successful. + m_socket->m_needsShutdown = true; } + + // Call the original handler with the error code and end. + m_socket->get_io_service ().wrap ( + BOOST_ASIO_MOVE_CAST (SharedHandlerPtr)(m_handler)) + (ec); } - get_io_service ().post (HandlerCall (HandlerCall::Post (), - BOOST_ASIO_MOVE_CAST(HandlerCall)(origHandler), ec)); - } - - //-------------------------------------------------------------------------- - - void on_detect_ssl (HandshakeDetectLogicSSL3& logic, - error_code& ec, ConstBuffers const& buffers) - { - bassert (is_server ()); - - if (! ec) + // This implements the asynchronous version of the loop found + // in do_handshake. It gets itself called repeatedly until the + // state resolves to a final handshake or an error occurs. + // + void operator() (error_code const& ec_) { - if (logic.success ()) + error_code ec (ec_); + + do { - m_stream = new_ssl_stream (buffers); - m_needs_shutdown_stream = true; + if (m_socket->m_state == stateReady) + { + // Always need shutdown if handshake successful. + m_socket->m_needsShutdown = true; + break; + } - stream ().handshake (Socket::server, ec); - return; + switch (m_socket->m_state) + { + case stateHandshakeFinal: + { + // Have to set this beforehand even + // though we might get an error. + // + m_socket->m_state = stateReady; + m_socket->stream ().async_handshake (m_type, + SharedHandlerPtr (this)); + } + return; + + case stateExpectPROXY: + { + if (m_proxy.getLogic ().finished ()) + { + if (m_proxy.getLogic ().success ()) + { + m_socket->setProxyInfo (m_proxy.getLogic ().getInfo ()); + + // Strip off the PROXY flag. + m_socket->m_flags = m_socket->m_flags.without (Flag::proxy); + + // Update handshake state with the leftover bytes. + ec = m_socket->initHandshake (m_type, m_buffer.data ()); + break; + } + + // Didn't get the PROXY handshake we needed + ec = m_socket->handshake_error (); + break; + } + + m_proxy.async_detect (m_socket->next_layer (), + m_buffer, SharedHandlerPtr (this)); + } + return; + + case stateDetectSSL: + { + if (m_ssl.getLogic ().finished ()) + { + // Was it SSL? + if (m_ssl.getLogic ().success ()) + { + // Convert the ssl flag to ssl_required + m_socket->m_flags = m_socket->m_flags.with ( + Flag::ssl_required).without (Flag::ssl); + } + else + { + // Not SSL, strip the ssl flag + m_socket->m_flags = m_socket->m_flags.without (Flag::ssl); + } + + // Update handshake state with the leftover bytes. + ec = m_socket->initHandshake (m_type, m_buffer.data ()); + break; + } + + m_ssl.async_detect (m_socket->next_layer (), + m_buffer, SharedHandlerPtr (this)); + } + return; + + case stateNone: + case stateReady: + case stateHandshake: + default: + fatal_error ("invalid state"); + ec = m_socket->handshake_error (); + } } + while (! ec); - m_stream = new_plain_stream (buffers); - } - } + bassert (ec || (m_socket->m_state == stateReady && m_socket->m_needsShutdown)); - // origHandler cannot be a reference since it gets move-copied - void on_async_detect_ssl (HandshakeDetectLogicSSL3& logic, - error_code const& ec, ConstBuffers const& buffers, HandlerCall origHandler) - { - if (! ec) - { - if (logic.success ()) - { - m_stream = new_ssl_stream (buffers); - m_needs_shutdown_stream = true; - // Is this a continuation? - stream ().async_handshake (Socket::server, origHandler); - return; - } - - // Note, this assignment will destroy the - // detector, and our allocated callback along with it. - m_stream = new_plain_stream (buffers); + // Call the original handler with the error code and end. + m_socket->get_io_service ().wrap ( + BOOST_ASIO_MOVE_CAST (SharedHandlerPtr)(m_handler)) + (ec); } - // Is this a continuation? - get_io_service ().post (HandlerCall (HandlerCall::Post (), - origHandler, ec)); - } - - //-------------------------------------------------------------------------- - // - // SSL3 Callback - // - class SSL3DetectCallback - : public HandshakeDetectStream ::Callback - { - public: - typedef HandshakeDetectLogicSSL3 LogicType; - - explicit SSL3DetectCallback (MultiSocketType * owner) - : m_owner (owner) + bool is_continuation () { - } - - void on_detect (LogicType& logic, error_code& ec, - ConstBuffers const& buffers) - { - m_owner->on_detect_ssl (logic, ec, buffers); - } - - void on_async_detect (LogicType& logic, error_code const& ec, - ConstBuffers const& buffers, HandlerCall const& origHandler) - { - m_owner->on_async_detect_ssl (logic, ec, buffers, origHandler); + return m_running || m_handler->is_continuation (); } private: - MultiSocketType * m_owner; + SharedHandlerPtr m_handler; + MultiSocketType * m_socket; + handshake_type const m_type; + boost::asio::basic_streambuf m_buffer; + HandshakeDetectorType m_proxy; + HandshakeDetectorType m_ssl; + bool m_running; }; - //-------------------------------------------------------------------------- - // - // Utilities - // - - bool is_server () const noexcept - { - return m_flags.set (Flag::server_role); - } - - bool is_client () const noexcept - { - return m_flags.set (Flag::client_role); - } - - bool is_peer () const noexcept - { - bassert (is_client () || is_server () || (m_flags == Flag (Flag::peer))); - return m_flags == Flag (Flag::peer); - } - private: Flag m_flags; - bool m_needs_handshake; - bool m_needs_shutdown; - bool m_needs_shutdown_stream; - StreamSocket m_next_layer; + State m_state; ScopedPointer m_stream; + bool m_needsShutdown; + StreamSocket m_next_layer; boost::asio::io_service::strand m_strand; - HandlerCall::Context m_context; }; #endif