From 0feccfd7b061c66dcf82584065d7296df9afb883 Mon Sep 17 00:00:00 2001 From: Vinnie Falco Date: Fri, 16 Aug 2013 06:23:33 -0700 Subject: [PATCH] MultiSocket with ssl handshake detection --- .../sockets/ripple_MultiSocket.cpp | 51 +- .../ripple_asio/sockets/ripple_MultiSocket.h | 23 +- .../sockets/ripple_MultiSocketType.h | 688 ++++++++++++++---- 3 files changed, 588 insertions(+), 174 deletions(-) diff --git a/modules/ripple_asio/sockets/ripple_MultiSocket.cpp b/modules/ripple_asio/sockets/ripple_MultiSocket.cpp index 11f661cef..9fd2e0ce7 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocket.cpp +++ b/modules/ripple_asio/sockets/ripple_MultiSocket.cpp @@ -4,35 +4,36 @@ */ //============================================================================== -MultiSocket::Options::Options (Flags flags) - : useClientSsl (false) - , enableServerSsl (false) - , requireServerSsl (false) - , requireServerProxy (false) -{ - setFromFlags (flags); -} - -void MultiSocket::Options::setFromFlags (Flags flags) -{ - useClientSsl = (flags & client_ssl) != 0; - enableServerSsl = (flags & (server_ssl | server_ssl_required)) != 0; - requireServerSsl = (flags & server_ssl_required) != 0; - requireServerProxy = (flags & server_proxy) !=0; -} - -//------------------------------------------------------------------------------ - MultiSocket* MultiSocket::New (boost::asio::io_service& io_service, int flags) { return new MultiSocketType (io_service, flags); } +struct RippleTlsContextHolder +{ + RippleTlsContextHolder () + : m_context (RippleTlsContext::New ()) + { + } + + RippleTlsContext& get () + { + return *m_context; + } + +private: + ScopedPointer m_context; +}; + +static RippleTlsContextHolder s_context_holder; + SslContextBase::BoostContextType &MultiSocket::getRippleTlsBoostContext () { - static ScopedPointer context (RippleTlsContext::New ()); + // This doesn't work because VS2012 doesn't wrap + // it with a thread-safety preamble!! + //static ScopedPointer context (RippleTlsContext::New ()); - return context->getBoostContext (); + return s_context_holder.get ().getBoostContext (); } //------------------------------------------------------------------------------ @@ -182,29 +183,21 @@ public: { check_precondition (! MultiSocket::Flag (extraClientFlags).any_set (MultiSocket::Flag::client_role | MultiSocket::Flag::server_role)); -#if 0 run (MultiSocket::Flag::client_role | extraClientFlags, MultiSocket::Flag::server_role | extraServerFlags); -#else - run_async (MultiSocket::Flag::client_role | extraClientFlags, MultiSocket::Flag::server_role | extraServerFlags); -#endif } template void testProtocol () { -#if 0 // These should pass. run (0, 0); run (MultiSocket::Flag::client_role, 0); run (0, MultiSocket::Flag::server_role); run (MultiSocket::Flag::client_role, MultiSocket::Flag::server_role); -#endif -#if 0 // These should pass testFlags (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl_required); testFlags (0, MultiSocket::Flag::ssl); -#endif testFlags (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl); } diff --git a/modules/ripple_asio/sockets/ripple_MultiSocket.h b/modules/ripple_asio/sockets/ripple_MultiSocket.h index 7991cb9f3..cb42aa5c5 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocket.h +++ b/modules/ripple_asio/sockets/ripple_MultiSocket.h @@ -74,33 +74,12 @@ public: server_proxy = 8 }; - struct Options - { - Options (Flags flags = none); - - // Always perform SSL handshake as client role - bool useClientSsl; - - // Enable optional SSL capability as server role - bool enableServerSsl; - - // Require SSL as server role. - // Does not require that enableServerSsl is set - bool requireServerSsl; - - // Require PROXY protocol handshake as server role - bool requireServerProxy; - - private: - void setFromFlags (Flags flags); - }; - static MultiSocket* New (boost::asio::io_service& io_service, int flags = 0); // Ripple uses a SSL/TLS context with specific parameters and this returns // a reference to the corresponding boost::asio::ssl::context object. // - static SslContextBase::BoostContextType &getRippleTlsBoostContext (); + static SslContextBase::BoostContextType& getRippleTlsBoostContext (); }; #endif diff --git a/modules/ripple_asio/sockets/ripple_MultiSocketType.h b/modules/ripple_asio/sockets/ripple_MultiSocketType.h index 607e12182..2f446f75a 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocketType.h +++ b/modules/ripple_asio/sockets/ripple_MultiSocketType.h @@ -9,18 +9,79 @@ //------------------------------------------------------------------------------ +// should really start using this +namespace detail +{ +} + +//------------------------------------------------------------------------------ + +class ssl_detector_base +{ +protected: + typedef boost::system::error_code error_code; + +public: + /** Called when the state is known. + + This could be called from any thread, most likely an io_service + thread but don't rely on that. + + The Callback must be allocated via operator new. + */ + struct Callback + { + virtual ~Callback () { } + + /** Called for synchronous ssl detection. + + Note that the storage for the buffers passed to the + callback is owned by the detector class and becomes + invalid when the detector class is destroyed, which is + a common thing to do from inside your callback. + + @param ec A modifiable error code that becomes the return + value of handshake. + @param buffers The bytes that were read in. + @param is_ssl True if the sequence is an ssl handshake. + */ + virtual void on_detect (error_code& ec, + ConstBuffers const& buffers, bool is_ssl) = 0; + + virtual void on_async_detect (error_code const& ec, + ConstBuffers const& buffers, ErrorCall const& origHandler, + bool is_ssl) = 0; + + #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE + virtual void on_async_detect (error_code const& ec, + ConstBuffers const& buffers, TransferCall const& origHandler, + bool is_ssl) = 0; + #endif + + }; +}; + +/** A stream that can detect an SSL handshake. +*/ template class ssl_detector - : public boost::asio::ssl::stream_base + : public ssl_detector_base + , public boost::asio::ssl::stream_base , public boost::asio::socket_base { private: + typedef ssl_detector this_type; + typedef boost::asio::streambuf buffer_type; typedef typename boost::remove_reference ::type stream_type; public: + /** This takes ownership of the callback. + The callback must be allocated with operator new. + */ template - ssl_detector (Arg& arg) - : m_next_layer (arg) + ssl_detector (Callback* callback, Arg& arg) + : m_callback (callback) + , m_next_layer (arg) , m_stream (m_next_layer) { } @@ -31,7 +92,7 @@ public: { return m_next_layer.get_io_service (); } - + // basic_socket typedef typename stream_type::protocol_type protocol_type; @@ -49,50 +110,189 @@ public: // ssl::stream - boost::system::error_code handshake (handshake_type type, boost::system::error_code& ec) + error_code handshake (handshake_type type, error_code& ec) { - ec = boost::system::error_code (); - - m_buffer.commit (boost::asio::read (m_next_layer, - m_buffer.prepare (m_detector.max_needed ()), ec)); - - if (!ec) - { - bool const finished = m_detector.analyze (m_buffer.data ()); - - if (finished) - { - } - } - - return ec; + return do_handshake (type, ec, ConstBuffers ()); } template - BOOST_ASIO_INITFN_RESULT_TYPE(HandshakeHandler, void (boost::system::error_code)) - async_handshake(handshake_type type, BOOST_ASIO_MOVE_ARG(HandshakeHandler) handler) + BOOST_ASIO_INITFN_RESULT_TYPE(HandshakeHandler, void (error_code)) + async_handshake (handshake_type type, BOOST_ASIO_MOVE_ARG(HandshakeHandler) handler) { +#if BEAST_ASIO_HAS_FUTURE_RETURNS + boost::asio::detail::async_result_init< + ErrorCall, void (error_code)> init( + BOOST_ASIO_MOVE_CAST(HandshakeHandler)(handler)); + // init.handler is copied + m_origHandler = ErrorCall (HandshakeHandler(init.handler)); + bassert (m_origBufferedHandler.isNull ()); + async_do_handshake (type, ConstBuffers ()); + return init.result.get(); +#else + m_origHandler = ErrorCall (handler); + bassert (m_origBufferedHandler.isNull ()); + async_do_handshake (type, ConstBuffers ()); +#endif } #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE template - void handshake (handshake_type type, const ConstBufferSequence& buffers) + error_code handshake (handshake_type type, + const ConstBufferSequence& buffers, error_code& ec) { + return do_handshake (type, ec, ConstBuffers (buffers));; } template - BOOST_ASIO_INITFN_RESULT_TYPE(BufferedHandshakeHandler, void (boost::system::error_code, std::size_t)) + BOOST_ASIO_INITFN_RESULT_TYPE(BufferedHandshakeHandler, void (error_code, std::size_t)) async_handshake(handshake_type type, const ConstBufferSequence& buffers, BOOST_ASIO_MOVE_ARG(BufferedHandshakeHandler) handler) { +#if BEAST_ASIO_HAS_FUTURE_RETURNS + boost::asio::detail::async_result_init< + BufferedHandshakeHandler, void (error_code, std::size_t)> init( + BOOST_ASIO_MOVE_CAST(BufferedHandshakeHandler)(handler)); + // init.handler is copied + m_origBufferedHandler = TransferCall (BufferedHandshakeHandler(init.handler)); + bassert (m_origHandler.isNull ()); + async_do_handshake (type, ConstBuffers (buffers)); + return init.result.get(); +#else + m_origBufferedHandler = TransferCall (handler); + bassert (m_origHandler.isNull ()); + async_do_handshake (type, ConstBuffers (buffers)); +#endif } #endif + //-------------------------------------------------------------------------- + + error_code do_handshake (handshake_type, error_code& ec, ConstBuffers const& buffers) + { + ec = error_code (); + + // Transfer caller data to our buffer. + m_buffer.commit (boost::asio::buffer_copy (m_buffer.prepare ( + boost::asio::buffer_size (buffers)), buffers)); + + do + { + std::size_t const available = m_buffer.size (); + std::size_t const needed = m_detector.max_needed (); + if (available < needed) + { + buffer_type::mutable_buffers_type buffers ( + m_buffer.prepare (needed - available)); + m_buffer.commit (m_next_layer.read_some (buffers, ec)); + } + + if (! ec) + { + m_detector.analyze (m_buffer.data ()); + + if (m_detector.finished ()) + { + m_callback->on_detect (ec, + ConstBuffers (m_buffer.data ()), + m_detector.success ()); + break; + } + + // If this fails it means we will never finish + check_postcondition (available < needed); + } + } + while (! ec); + + return ec; + } + + //-------------------------------------------------------------------------- + + void async_do_handshake (handshake_type type, ConstBuffers const& buffers) + { + // Transfer caller data to our buffer. + std::size_t const bytes_transferred (boost::asio::buffer_copy ( + m_buffer.prepare (boost::asio::buffer_size (buffers)), buffers)); + + // bootstrap the asynchronous loop + on_async_read_some (error_code (), bytes_transferred); + } + + // asynchronous version of the synchronous loop found in handshake () + // + void on_async_read_some (error_code const& ec, std::size_t bytes_transferred) + { + if (! ec) + { + m_buffer.commit (bytes_transferred); + + std::size_t const available = m_buffer.size (); + std::size_t const needed = m_detector.max_needed (); + + m_detector.analyze (m_buffer.data ()); + + if (m_detector.finished ()) + { + #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE + if (! m_origBufferedHandler.isNull ()) + { + bassert (m_origHandler.isNull ()); + // continuation? + m_callback->on_async_detect (ec, ConstBuffers (m_buffer.data ()), + m_origBufferedHandler, m_detector.success ()); + return; + } + #endif + + bassert (! m_origHandler.isNull ()) + // continuation? + m_callback->on_async_detect (ec, + ConstBuffers (m_buffer.data ()), m_origHandler, + m_detector.success ()); + return; + } + + // If this fails it means we will never finish + check_postcondition (available < needed); + + buffer_type::mutable_buffers_type buffers (m_buffer.prepare ( + needed - available)); + + // need a continuation hook here? + m_next_layer.async_read_some (buffers, boost::bind ( + &this_type::on_async_read_some, this, boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred)); + + return; + } + + #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE + if (! m_origBufferedHandler.isNull ()) + { + bassert (m_origHandler.isNull ()); + // continuation? + m_callback->on_async_detect (ec, ConstBuffers (m_buffer.data ()), + m_origBufferedHandler, false); + return; + } + #endif + + bassert (! m_origHandler.isNull ()) + // continuation? + m_callback->on_async_detect (ec, + ConstBuffers (m_buffer.data ()), m_origHandler, + false); + } + private: + ScopedPointer m_callback; Stream m_next_layer; - boost::asio::streambuf m_buffer; + buffer_type m_buffer; boost::asio::buffered_read_stream m_stream; HandshakeDetectorType m_detector; + ErrorCall m_origHandler; + TransferCall m_origBufferedHandler; }; //------------------------------------------------------------------------------ @@ -103,22 +303,41 @@ template class MultiSocketType : public MultiSocket { +protected: +#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE + // For some reason, the ssl::stream buffered handshaking + // routines don't seem to work right. Setting this bool + // to false will bypass the ssl::stream buffered handshaking + // API and just resort to our PrefilledReadStream which works + // great and we need it anyway to handle the non-buffered case. + // + static bool const useBufferedHandshake = false; +#endif + + typedef boost::system::error_code error_code; + public: typedef typename boost::remove_reference ::type next_layer_type; typedef typename boost::add_reference ::type next_layer_type_ref; typedef typename next_layer_type::lowest_layer_type lowest_layer_type; typedef typename boost::asio::ssl::stream ssl_stream_type; + // These types are used internaly + // + + // The type of stream that the ssl_detector uses + typedef next_layer_type& ssl_detector_stream_type; + template - explicit MultiSocketType (Arg& arg, int flags = 0, Options const& options = Options ()) - : m_options (options) - , m_flags (cleaned_flags (flags)) + explicit MultiSocketType (Arg& arg, int flags = 0) + : m_flags (cleaned_flags (flags)) , m_next_layer (arg) - , m_stream (new_current_stream (true)) + , m_stream (new_stream ()) , m_strand (get_io_service ()) { } +protected: Socket& stream () const noexcept { bassert (m_stream != nullptr); @@ -176,17 +395,17 @@ public: return nullptr; } - boost::system::error_code cancel (boost::system::error_code& ec) + error_code cancel (error_code& ec) { return lowest_layer ().cancel (ec); } - boost::system::error_code shutdown (Socket::shutdown_type what, boost::system::error_code& ec) + error_code shutdown (Socket::shutdown_type what, error_code& ec) { return lowest_layer ().shutdown (what, ec); } - boost::system::error_code close (boost::system::error_code& ec) + error_code close (error_code& ec) { return lowest_layer ().close (ec); } @@ -196,17 +415,17 @@ public: // basic_stream_socket // - std::size_t read_some (MutableBuffers const& buffers, boost::system::error_code& ec) + std::size_t read_some (MutableBuffers const& buffers, error_code& ec) { return stream ().read_some (buffers, ec); } - std::size_t write_some (ConstBuffers const& buffers, boost::system::error_code& ec) + std::size_t write_some (ConstBuffers const& buffers, error_code& ec) { return stream ().write_some (buffers, ec); } - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(TransferCall, void (boost::system::error_code, std::size_t)) + BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(TransferCall, void (error_code, std::size_t)) async_read_some (MutableBuffers const& buffers, BOOST_ASIO_MOVE_ARG(TransferCall) handler) { return stream ().async_read_some (buffers, m_strand.wrap (boost::bind ( @@ -214,7 +433,7 @@ public: boost::asio::placeholders::bytes_transferred))); } - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(TransferCall, void (boost::system::error_code, std::size_t)) + BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(TransferCall, void (error_code, std::size_t)) async_write_some (ConstBuffers const& buffers, BOOST_ASIO_MOVE_ARG(TransferCall) handler) { return stream ().async_write_some (buffers, m_strand.wrap (boost::bind ( @@ -229,72 +448,70 @@ public: bool requires_handshake () { - return stream ().requires_handshake (); + if (m_stream != nullptr) + return stream ().requires_handshake (); + return true; } - boost::system::error_code handshake (Socket::handshake_type type, boost::system::error_code& ec) + error_code handshake (Socket::handshake_type type, error_code& ec) { - if (! prepare_handshake (type, ec)) + if (set_handshake_stream (type, ec)) return stream ().handshake (type, ec); - return ec; } - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(ErrorCall, void (boost::system::error_code)) + BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(ErrorCall, void (error_code)) async_handshake (handshake_type type, BOOST_ASIO_MOVE_ARG(ErrorCall) handler) { - boost::system::error_code ec; - - if (! prepare_handshake (type, ec)) - return stream ().async_handshake (type, BOOST_ASIO_MOVE_CAST(ErrorCall)(handler)); - - return get_io_service ().post (boost::bind ( - BOOST_ASIO_MOVE_CAST(ErrorCall)(handler), ec)); + error_code ec; + if (set_handshake_stream (type, ec)) + return stream ().async_handshake (type, + BOOST_ASIO_MOVE_CAST(ErrorCall)(handler)); + return get_io_service ().post (CompletionCall (handler, ec)); } #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE - boost::system::error_code handshake (handshake_type type, - ConstBuffers const& buffers, boost::system::error_code& ec) + error_code handshake (handshake_type type, + ConstBuffers const& buffers, error_code& ec) { - if (! prepare_handshake (type, ec)) + if (set_handshake_stream (type, ec)) return stream ().handshake (type, buffers, ec); - return ec; } - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(TransferCall, void (boost::system::error_code, std::size_t)) + BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(TransferCall, void (error_code, std::size_t)) async_handshake (handshake_type type, ConstBuffers const& buffers, BOOST_ASIO_MOVE_ARG(TransferCall) handler) { - boost::system::error_code ec; - - if (! prepare_handshake (type, ec)) + error_code ec; + if (set_handshake_stream (type, ec)) return stream ().async_handshake (type, buffers, BOOST_ASIO_MOVE_CAST(TransferCall)(handler)); - - return get_io_service ().post (boost::bind ( - BOOST_ASIO_MOVE_CAST(TransferCall)(handler), ec, 0)); + return get_io_service ().post (CompletionCall (handler, ec, 0)); } #endif - boost::system::error_code shutdown (boost::system::error_code& ec) + error_code shutdown (error_code& ec) { + bassert (requires_handshake ()); + if (stream ().requires_handshake ()) return stream ().shutdown (ec); return handshake_error (ec); } - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(ErrorCall, void (boost::system::error_code)) + BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(ErrorCall, void (error_code)) async_shutdown (BOOST_ASIO_MOVE_ARG(ErrorCall) handler) { + bassert (requires_handshake ()); + if (stream ().requires_handshake ()) return stream ().async_shutdown (BOOST_ASIO_MOVE_CAST(ErrorCall)(handler)); - boost::system::error_code ec; + error_code ec; handshake_error (ec); - return get_io_service ().post (boost::bind ( - BOOST_ASIO_MOVE_CAST(ErrorCall)(handler), ec)); + return get_io_service ().post (CompletionCall (handler, ec)); } //-------------------------------------------------------------------------- @@ -332,23 +549,63 @@ public: //-------------------------------------------------------------------------- + // Create a plain stream that just wraps the next layer. + // Socket* new_plain_stream () { return new SocketWrapper (m_next_layer); } + // Create a plain stream but front-load it with some bytes. + // A copy of the buffers is made. + // + template + Socket* new_plain_stream (ConstBufferSequence const& buffers) + { + using namespace boost; + if (asio::buffer_size (buffers) > 0) + { + typedef PrefilledReadStream this_type; + return new SocketWrapper (m_next_layer, buffers); + } + return new_plain_stream (); + } + + // Creates an ssl stream + // Socket* new_ssl_stream () { - typedef typename boost::asio::ssl::stream ssl_stream_type; - - return new SocketWrapper ( + using namespace boost; + typedef typename asio::ssl::stream this_type; + return new SocketWrapper ( m_next_layer, MultiSocket::getRippleTlsBoostContext ()); } + // Creates an ssl stream, but front-load it with some bytes. + // A copy of the buffers is made. + // + template + Socket* new_ssl_stream (ConstBufferSequence const& buffers) + { + using namespace boost; + if (asio::buffer_size (buffers) > 0) + { + bassert (asio::buffer_size (buffers) > 0); + typedef PrefilledReadStream next_type; + typedef asio::ssl::stream this_type; + SocketWrapper * const socket = new SocketWrapper ( + m_next_layer, MultiSocket::getRippleTlsBoostContext ()); + socket->this_layer().next_layer().fill (buffers); + return socket; + } + return new_ssl_stream (); + } + Socket* new_detect_ssl_stream () { - return new SocketWrapper > (m_next_layer); - //return nullptr; + ssl_detector_callback* callback = new ssl_detector_callback (this); + typedef ssl_detector this_type; + return new SocketWrapper (callback, m_next_layer); } Socket* new_proxy_stream () @@ -356,64 +613,70 @@ public: return nullptr; } - // Creates a Socket representing what the current stream should - // be based on the flags and the state of the handshaking engine. - // - Socket* new_current_stream (bool createPlain) - { - Socket* result = nullptr; + //-------------------------------------------------------------------------- - if (m_flags.set (Flag::client_role)) + // Returns a stream appropriate for the current settings. If the + // return value is nullptr, it means that the stream cannot be determined + // until handshake or async_handshake is called. At that point, it will + // be necessary to use new_handshake_stream + // + Socket* new_stream () + { + Socket* socket = nullptr; + + if (is_client ()) { if (m_flags.set (Flag::proxy)) { - // unsupported function + // sending PROXY unimplemented for now SocketBase::pure_virtual (); + socket = nullptr; } - - if (m_flags.set (Flag::ssl)) + else if (! m_flags.any_set ( + Flag::ssl)) { - result = new_ssl_stream (); + socket = new_plain_stream (); } else { - result = nullptr; + // A call to handshake or async_handshake is + // required to determine the type of stream to create. + // + socket = nullptr; + } + } + else if (is_server ()) + { + // As long as no flags indicate a handshake may be + // required, we can start out with a plain stream. + // + if (! m_flags.any_set ( + Flag::ssl | Flag::ssl_required | Flag::proxy)) + { + socket = new_plain_stream (); + } + else + { + // A call to handshake or async_handshake is + // required to determine the type of stream to create. + // + socket = nullptr; } } else { - if (m_flags.set (Flag::proxy)) - { - result = new_proxy_stream (); - } - else if (m_flags.set (Flag::ssl_required)) - { - result = new_ssl_stream (); - } - else if (m_flags.set (Flag::ssl)) - { - result = new_detect_ssl_stream (); - } - else - { - result = nullptr; - } + // peer, plain stream, no handshake + socket = new_plain_stream (); } - if (createPlain) - { - if (result == nullptr) - result = new_plain_stream (); - } - - return result; + return socket; } //-------------------------------------------------------------------------- // Bottleneck to indicate a failed handshake. // - boost::system::error_code handshake_error (boost::system::error_code& ec) + error_code handshake_error (error_code& ec) { // VFALCO TODO maybe use a ripple error category? // set this to something custom that we can recognize later? @@ -426,38 +689,217 @@ public: // If the requested handshake type is incompatible with the socket flags // set on construction, an error is returned. // - boost::system::error_code prepare_handshake (handshake_type type, - boost::system::error_code& ec) + // Returns true if the stream was set + // + bool set_handshake_stream (handshake_type type, error_code& ec) { - ec = boost::system::error_code (); + // We better require a handshake if this is getting called + bassert (requires_handshake ()); - // If we constructed as 'peer' then become a client - // or a server as needed according to the handshake type. + // If a stream already exists, some logic went wrong. + bassert (m_stream == nullptr); + + ec = error_code (); + + // If we constructed as a peer role then return + // an error since there is no handshake required. // if (m_flags == Flag (Flag::peer)) - m_flags = (type == server) ? - Flag (Flag::server_role) : Flag (Flag::client_role); + { + handshake_error (ec); + return false; + } - // If we constructed with a specific role, the handshake - // must match the role. + // If we constructed in a server or client role and + // the handshake type doesn't match, return an error. // if ( (type == client && ! m_flags.set (Flag::client_role)) || (type == server && ! m_flags.set (Flag::server_role))) - return handshake_error (ec); + { + handshake_error (ec); + return false; + } - // Figure out what type of new stream we need. - // We pass false to avoid creating a redundant plain stream - Socket* socket = new_current_stream (false); + Socket* socket = nullptr; - if (socket != nullptr) - m_stream = socket; + if (m_flags.set (Flag::client_role)) + { + // should have filtered this out already + bassert (! m_flags.set (Flag::proxy)); - return ec; + if (m_flags.set (Flag::ssl)) + { + socket = new_ssl_stream (); + } + else + { + // bad + } + } + else + { + if (m_flags.set (Flag::proxy)) + { + socket = new_proxy_stream (); + } + else if (m_flags.set (Flag::ssl_required)) + { + socket = new_ssl_stream (); + } + else if (m_flags.set (Flag::ssl)) + { + socket = new_detect_ssl_stream (); + } + else + { + // bad + } + } + + check_postcondition (socket != nullptr); + + m_stream = socket; + + return true; + } + + //-------------------------------------------------------------------------- + + void on_detect_ssl (error_code& ec, ConstBuffers const& buffers, bool is_ssl) + { + bassert (is_server ()); + + if (! ec) + { + if (is_ssl) + { + #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE + if (useBufferedHandshake) + { + m_stream = new_ssl_stream (); + stream ().handshake (Socket::server, buffers, ec); + return; + } + #endif + + m_stream = new_ssl_stream (buffers); + stream ().handshake (Socket::server, ec); + return; + } + + m_stream = new_plain_stream (buffers); + } + } + + // origHandler cannot be a reference since it gets move-copied + void on_async_detect_ssl (error_code const& ec, + ConstBuffers const& buffers, ErrorCall origHandler, bool is_ssl) + { + if (! ec) + { + if (is_ssl) + { + m_stream = new_ssl_stream (buffers); + // Is this a continuation? + stream ().async_handshake (Socket::server, origHandler); + return; + } + + // Note, this assignment will destroy the + // detector, and our allocated callback along with it. + m_stream = new_plain_stream (buffers); + } + + // Is this a continuation? + get_io_service ().post (CompletionCall ( + origHandler, ec)); + } + +#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE + // origHandler cannot be a reference since it gets move-copied + void on_async_detect_ssl (error_code const& ec, + ConstBuffers const& buffers, TransferCall origHandler, bool is_ssl) + { + std::size_t bytes_transferred = 0; + + if (! ec) + { + if (is_ssl) + { + // Is this a continuation? + m_stream = new_ssl_stream (); + stream ().async_handshake (Socket::server, buffers, origHandler); + return; + } + + // Note, this assignment will destroy the + // detector, and our allocated callback along with it. + m_stream = new_plain_stream (buffers); + + // Since the handshake was not SSL we leave + // bytes_transferred at zero since we didn't use anything. + } + + // Is this a continuation? + get_io_service ().post (CompletionCall ( + origHandler, ec, bytes_transferred)); + } +#endif + + //-------------------------------------------------------------------------- + // + // ssl_detector calback + // + class ssl_detector_callback : public ssl_detector_base::Callback + { + public: + explicit ssl_detector_callback (MultiSocketType * owner) + : m_owner (owner) + { + } + + void on_detect (error_code& ec, + ConstBuffers const& buffers, bool is_ssl) + { + m_owner->on_detect_ssl (ec, buffers, is_ssl); + } + + void on_async_detect (error_code const& ec, + ConstBuffers const& buffers, ErrorCall const& origHandler, + bool is_ssl) + { + m_owner->on_async_detect_ssl (ec, buffers, origHandler, is_ssl); + } + + #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE + void on_async_detect (error_code const& ec, + ConstBuffers const& buffers, TransferCall const& origHandler, + bool is_ssl) + { + m_owner->on_async_detect_ssl (ec, buffers, origHandler, is_ssl); + } + #endif + + private: + MultiSocketType * m_owner; + }; + + //-------------------------------------------------------------------------- + // + // Utilities + // + + bool is_server () const noexcept + { + return m_flags.set (Flag::server_role); + } + + bool is_client () const noexcept + { + return m_flags.set (Flag::client_role); } private: - Options m_options; // DEPRECATED - Flag m_flags; StreamSocket m_next_layer; ScopedPointer m_stream;