From 7b9a5e8753b09fe39d4b621cecc115ffc48836e7 Mon Sep 17 00:00:00 2001 From: Vinnie Falco Date: Sat, 24 Aug 2013 18:00:08 -0700 Subject: [PATCH] Update MultiSocket for improved Socket --- .../sockets/ripple_MultiSocket.cpp | 16 +- .../ripple_asio/sockets/ripple_MultiSocket.h | 14 +- .../sockets/ripple_MultiSocketType.h | 266 ++++++++++++------ 3 files changed, 211 insertions(+), 85 deletions(-) diff --git a/modules/ripple_asio/sockets/ripple_MultiSocket.cpp b/modules/ripple_asio/sockets/ripple_MultiSocket.cpp index 260602b507..f65fe6eb43 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocket.cpp +++ b/modules/ripple_asio/sockets/ripple_MultiSocket.cpp @@ -9,6 +9,20 @@ MultiSocket* MultiSocket::New (boost::asio::io_service& io_service, int flags) return new MultiSocketType (io_service, flags); } +MultiSocket* MultiSocket::New (boost::asio::ip::tcp::socket& socket, int flags) +{ + return new MultiSocketType (socket, flags); +} + +MultiSocket* MultiSocket::New ( + boost::asio::ssl::stream < + boost::asio::ip::tcp::socket&>& stream, int flags) +{ + return new MultiSocketType < + boost::asio::ssl::stream < + boost::asio::ip::tcp::socket&>& > (stream, flags); +} + struct RippleTlsContextHolder { RippleTlsContextHolder () @@ -181,7 +195,6 @@ public: TestPeerLogicSyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); - PeerTest::run , TestPeerLogicAsyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); @@ -226,7 +239,6 @@ public: testFlags (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl_required); - // SSL-Detect tests testFlags (0, MultiSocket::Flag::ssl); diff --git a/modules/ripple_asio/sockets/ripple_MultiSocket.h b/modules/ripple_asio/sockets/ripple_MultiSocket.h index cb42aa5c5f..888b78f4ef 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocket.h +++ b/modules/ripple_asio/sockets/ripple_MultiSocket.h @@ -64,7 +64,7 @@ public: private: int m_flags; }; - + enum Flags { none = 0, @@ -74,7 +74,19 @@ public: server_proxy = 8 }; + virtual SSL* native_handle () = 0; + + /* + // This would be the underlying StreamSocket template parameter + typedef boost::asio::ip::tcp::socket NativeSocketType; + typedef boost::asio::ssl::stream SslStreamType; + virtual boost::asio::ip::tcp::socket& getNativeSocket () = 0; + virtual SslStreamType& getSslStream () = 0; + */ + static MultiSocket* New (boost::asio::io_service& io_service, int flags = 0); + static MultiSocket* New (boost::asio::ip::tcp::socket& socket, int flags = 0); + static MultiSocket* New (boost::asio::ssl::stream & stream, int flags = 0); // Ripple uses a SSL/TLS context with specific parameters and this returns // a reference to the corresponding boost::asio::ssl::context object. diff --git a/modules/ripple_asio/sockets/ripple_MultiSocketType.h b/modules/ripple_asio/sockets/ripple_MultiSocketType.h index 58ccca8d01..fd05eeabb1 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocketType.h +++ b/modules/ripple_asio/sockets/ripple_MultiSocketType.h @@ -7,6 +7,8 @@ #ifndef RIPPLE_MULTISOCKETTYPE_H_INCLUDED #define RIPPLE_MULTISOCKETTYPE_H_INCLUDED +#define MULTISOCKET_USE_STRAND 1 + /** Template for producing instances of MultiSocket */ template @@ -20,16 +22,16 @@ public: typedef typename boost::remove_reference ::type next_layer_type; typedef typename boost::add_reference ::type next_layer_type_ref; typedef typename next_layer_type::lowest_layer_type lowest_layer_type; - typedef typename boost::asio::ssl::stream ssl_stream_type; template explicit MultiSocketType (Arg& arg, int flags = 0) : m_flags (flags) , m_state (stateNone) + , m_verify_mode (0) , m_stream (nullptr) , m_needsShutdown (false) , m_next_layer (arg) - , m_strand (get_io_service ()) + , m_native_ssl_handle (nullptr) { // See if our flags allow us to go directly // into the ready state with an active stream. @@ -50,6 +52,20 @@ protected: }; //-------------------------------------------------------------------------- + // + // MultiSocket + // + + SSL* native_handle () + { + bassert (m_native_ssl_handle != nullptr); + return m_native_ssl_handle; + } + + //-------------------------------------------------------------------------- + // + // MultiSocketType + // /** The current stream we are passing everything through. This object gets dynamically created and replaced with other @@ -61,6 +77,7 @@ protected: return *m_stream; } +#if 0 next_layer_type& next_layer () noexcept { return m_next_layer; @@ -80,6 +97,30 @@ protected: { return m_next_layer.lowest_layer (); } +#endif + + //-------------------------------------------------------------------------- + // + // Socket + // + + // We will return the underlying StreamSocket + void* this_layer_ptr (char const* type_name) const + { + char const* const name (typeid (next_layer_type).name ()); + if (strcmp (name, type_name) == 0) + return const_cast (static_cast (&m_next_layer)); + return nullptr; + } + + //-------------------------------------------------------------------------- + + // What would we return from here? + bool native_handle (char const*, void*) const + { + pure_virtual_called (__FILE__, __LINE__); + return false; + } //-------------------------------------------------------------------------- // @@ -96,35 +137,29 @@ protected: // basic_socket // - void* lowest_layer (char const* type_name) const + /** Always the lowest_layer of the underlying StreamSocket. */ + void* lowest_layer_ptr (char const* type_name) const { char const* const name (typeid (lowest_layer_type).name ()); if (strcmp (name, type_name) == 0) - return const_cast (static_cast (&lowest_layer ())); - return nullptr; - } - - void* native_handle (char const* type_name) const - { - char const* const name (typeid (next_layer_type).name ()); - if (strcmp (name, type_name) == 0) - return const_cast (static_cast (&next_layer ())); + return const_cast ( + static_cast (&m_next_layer.lowest_layer ())); return nullptr; } error_code cancel (error_code& ec) { - return lowest_layer ().cancel (ec); + return m_next_layer.lowest_layer ().cancel (ec); } error_code shutdown (Socket::shutdown_type what, error_code& ec) { - return lowest_layer ().shutdown (what, ec); + return m_next_layer.lowest_layer ().shutdown (what, ec); } error_code close (error_code& ec) { - return lowest_layer ().close (ec); + return m_next_layer.lowest_layer ().close (ec); } //-------------------------------------------------------------------------- @@ -144,26 +179,14 @@ protected: void async_read_some (MutableBuffers const& buffers, SharedHandlerPtr handler) { -#if 0 - stream ().async_read_some (buffeers, + stream ().async_read_some (buffers, BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler)); -#else - stream ().async_read_some (buffers, m_strand.wrap ( - boost::bind (handler, boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred))); -#endif } void async_write_some (ConstBuffers const& buffers, SharedHandlerPtr handler) { -#if 0 stream ().async_write_some (buffers, BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler)); -#else - stream ().async_write_some (buffers, m_strand.wrap ( - boost::bind ( handler, boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred))); -#endif } //-------------------------------------------------------------------------- @@ -171,6 +194,19 @@ protected: // ssl::stream // + /** Retrieve the next_layer. + The next_layer_type is always the StreamSocket, regardless of what + we have put around it in terms of SSL or buffering. + */ + void* next_layer_ptr (char const* type_name) const + { + char const* const name (typeid (next_layer_type).name ()); + if (strcmp (name, type_name) == 0) + return const_cast (static_cast (&m_next_layer)); + return nullptr; + + } + /** Determine if the caller needs to call a handshaking function. This is also used to determine if the handshaking shutdown() has to be called. @@ -184,6 +220,16 @@ protected: //-------------------------------------------------------------------------- + void set_verify_mode (int verify_mode) + { + if (m_ssl_stream != nullptr) + m_ssl_stream->set_verify_mode (verify_mode); + else + m_verify_mode = verify_mode; + } + + //-------------------------------------------------------------------------- + error_code handshake (Socket::handshake_type type, error_code& ec) { return handshake (type, ConstBuffers (), ec); @@ -215,7 +261,7 @@ protected: ConstBuffers const& buffers, SharedHandlerPtr const& handler) { return get_io_service ().dispatch (SharedHandlerPtr ( - new AsyncOp (this, type, buffers, handler))); + new AsyncOp (*this, m_next_layer, type, buffers, handler))); } //-------------------------------------------------------------------------- @@ -252,7 +298,7 @@ protected: { // Our interface didn't require a shutdown but someone called // it anyone so generate an error code. - bassertfalse; + //bassertfalse; ec = handshake_error (); } @@ -285,29 +331,18 @@ protected: if (flags.set (Flag::ssl_required)) flags = flags.without (Flag::ssl); } - else - { - // if not client or server, clear out all the flags - flags = Flag::peer; - } return flags; } - bool is_server () const noexcept - { - return m_flags.set (Flag::server_role); - } - bool is_client () const noexcept { return m_flags.set (Flag::client_role); } - bool is_peer () const noexcept + bool is_server () const noexcept { - bassert (is_client () || is_server () || (m_flags == Flag (Flag::peer))); - return m_flags == Flag (Flag::peer); + return m_flags.set (Flag::server_role); } // Bottleneck to indicate a failed handshake. @@ -399,8 +434,23 @@ protected: } else { - m_state = stateReady; - m_stream = new_plain_stream (); + // We will determine client/server mode + // at the time handshake is called. + + bassert (! m_flags.set (Flag::proxy)); + m_flags = m_flags.without (Flag::proxy); + + if (m_flags.any_set (Flag::ssl | Flag::ssl_required)) + { + // We will decide stream type at handshake time + m_state = stateHandshake; + m_stream = nullptr; + } + else + { + m_state = stateReady; + m_stream = new_plain_stream (); + } } // We only set this to true in stateHandshake and @@ -441,9 +491,26 @@ protected: return handshake_error (); } - // Peer roles cannot handshake. - if (m_flags == Flag (Flag::peer)) - return handshake_error (); + // Set flags based on handshake if necessary + if (! m_flags.any_set (Flag::client_role | Flag::server_role)) + { + switch (type) + { + case client: + m_flags = m_flags.with (Flag::client_role); + break; + + case server: + m_flags = m_flags.with (Flag::server_role); + break; + + default: + fatal_error ("invalid tyoe"); + break; + } + + m_flags = cleaned_flags (m_flags); + } // Handshake type must match the role flags if ( (type == client && ! is_client ()) || @@ -502,11 +569,24 @@ protected: //-------------------------------------------------------------------------- + template + void set_ssl_stream (boost::asio::ssl::stream & ssl_stream) + { + m_ssl_stream = new SocketWrapper &> + (ssl_stream); + m_ssl_stream->set_verify_mode (m_verify_mode); + m_native_ssl_handle = ssl_stream.native_handle (); + } + + //-------------------------------------------------------------------------- + // Create a plain stream that just wraps the next layer. // Socket* new_plain_stream () { - return new SocketWrapper (m_next_layer); + typedef SocketWrapper Wrapper; + Wrapper* const socket (new Wrapper (m_next_layer)); + return socket; } // Create a plain stream but front-load it with some bytes. @@ -517,8 +597,10 @@ protected: { if (boost::asio::buffer_size (buffers) > 0) { - typedef PrefilledReadStream this_type; - return new SocketWrapper (m_next_layer, buffers); + typedef SocketWrapper > Wrapper; + Wrapper* const socket (new Wrapper (m_next_layer)); + socket->this_layer ().fill (buffers); + return socket; } return new_plain_stream (); } @@ -527,9 +609,16 @@ protected: // Socket* new_ssl_stream () { - typedef typename boost::asio::ssl::stream this_type; - return new SocketWrapper ( + typedef typename boost::asio::ssl::stream SslStream; +#if MULTISOCKET_USE_STRAND + typedef SocketWrapperStrand Wrapper; +#else + typedef SocketWrapper Wrapper; +#endif + Wrapper* const socket = new Wrapper ( m_next_layer, MultiSocket::getRippleTlsBoostContext ()); + set_ssl_stream (socket->this_layer ()); + return socket; } // Creates an ssl stream, but front-load it with some bytes. @@ -540,11 +629,17 @@ protected: { if (boost::asio::buffer_size (buffers) > 0) { - typedef PrefilledReadStream next_type; - typedef boost::asio::ssl::stream this_type; - SocketWrapper * const socket = new SocketWrapper ( + typedef boost::asio::ssl::stream < + PrefilledReadStream > SslStream; +#if MULTISOCKET_USE_STRAND + typedef SocketWrapperStrand Wrapper; +#else + typedef SocketWrapper Wrapper; +#endif + Wrapper* const socket = new Wrapper ( m_next_layer, MultiSocket::getRippleTlsBoostContext ()); - socket->this_layer().next_layer().fill (buffers); + socket->this_layer ().next_layer().fill (buffers); + set_ssl_stream (socket->this_layer ()); return socket; } return new_ssl_stream (); @@ -691,16 +786,18 @@ protected: // Composed asynchronous handshake operator // + template struct AsyncOp : ComposedAsyncOperation { typedef SharedHandlerAllocator Allocator; - AsyncOp (MultiSocketType * socket, + AsyncOp (MultiSocketType & owner, Stream& stream, handshake_type type, ConstBuffers const& buffers, SharedHandlerPtr const& handler) : ComposedAsyncOperation (sizeof (*this), handler) , m_handler (handler) - , m_socket (socket) + , m_owner (owner) + , m_stream (stream) , m_type (type) , m_buffer (std::numeric_limits ::max(), handler) , m_running (false) @@ -726,22 +823,22 @@ protected: { m_running = true; - error_code ec (m_socket->initHandshake (m_type)); + error_code ec (m_owner.initHandshake (m_type)); if (! ec) { - if (m_socket->m_state != stateReady) + if (m_owner.m_state != stateReady) { (*this) (ec); return; } // Always need shutdown if handshake successful. - m_socket->m_needsShutdown = true; + m_owner.m_needsShutdown = true; } // Call the original handler with the error code and end. - m_socket->get_io_service ().wrap ( + m_owner.get_io_service ().wrap ( BOOST_ASIO_MOVE_CAST (SharedHandlerPtr)(m_handler)) (ec); } @@ -756,22 +853,22 @@ protected: do { - if (m_socket->m_state == stateReady) + if (m_owner.m_state == stateReady) { // Always need shutdown if handshake successful. - m_socket->m_needsShutdown = true; + m_owner.m_needsShutdown = true; break; } - switch (m_socket->m_state) + switch (m_owner.m_state) { case stateHandshakeFinal: { // Have to set this beforehand even // though we might get an error. // - m_socket->m_state = stateReady; - m_socket->stream ().async_handshake (m_type, + m_owner.m_state = stateReady; + m_owner.stream ().async_handshake (m_type, SharedHandlerPtr (this)); } return; @@ -782,22 +879,22 @@ protected: { if (m_proxy.getLogic ().success ()) { - m_socket->setProxyInfo (m_proxy.getLogic ().getInfo ()); + m_owner.setProxyInfo (m_proxy.getLogic ().getInfo ()); // Strip off the PROXY flag. - m_socket->m_flags = m_socket->m_flags.without (Flag::proxy); + m_owner.m_flags = m_owner.m_flags.without (Flag::proxy); // Update handshake state with the leftover bytes. - ec = m_socket->initHandshake (m_type, m_buffer.data ()); + ec = m_owner.initHandshake (m_type, m_buffer.data ()); break; } // Didn't get the PROXY handshake we needed - ec = m_socket->handshake_error (); + ec = m_owner.handshake_error (); break; } - m_proxy.async_detect (m_socket->next_layer (), + m_proxy.async_detect (m_stream, m_buffer, SharedHandlerPtr (this)); } return; @@ -810,21 +907,21 @@ protected: if (m_ssl.getLogic ().success ()) { // Convert the ssl flag to ssl_required - m_socket->m_flags = m_socket->m_flags.with ( + m_owner.m_flags = m_owner.m_flags.with ( Flag::ssl_required).without (Flag::ssl); } else { // Not SSL, strip the ssl flag - m_socket->m_flags = m_socket->m_flags.without (Flag::ssl); + m_owner.m_flags = m_owner.m_flags.without (Flag::ssl); } // Update handshake state with the leftover bytes. - ec = m_socket->initHandshake (m_type, m_buffer.data ()); + ec = m_owner.initHandshake (m_type, m_buffer.data ()); break; } - m_ssl.async_detect (m_socket->next_layer (), + m_ssl.async_detect (m_stream, m_buffer, SharedHandlerPtr (this)); } return; @@ -834,15 +931,15 @@ protected: case stateHandshake: default: fatal_error ("invalid state"); - ec = m_socket->handshake_error (); + ec = m_owner.handshake_error (); } } while (! ec); - bassert (ec || (m_socket->m_state == stateReady && m_socket->m_needsShutdown)); + bassert (ec || (m_owner.m_state == stateReady && m_owner.m_needsShutdown)); // Call the original handler with the error code and end. - m_socket->get_io_service ().wrap ( + m_owner.get_io_service ().wrap ( BOOST_ASIO_MOVE_CAST (SharedHandlerPtr)(m_handler)) (ec); } @@ -854,7 +951,8 @@ protected: private: SharedHandlerPtr m_handler; - MultiSocketType * m_socket; + MultiSocketType & m_owner; + Stream& m_stream; handshake_type const m_type; boost::asio::basic_streambuf m_buffer; HandshakeDetectorType m_proxy; @@ -865,10 +963,14 @@ protected: private: Flag m_flags; State m_state; + int m_verify_mode; ScopedPointer m_stream; + ScopedPointer m_ssl_stream; // the ssl portion of our stream if it exists + bool m_needsShutdown; StreamSocket m_next_layer; - boost::asio::io_service::strand m_strand; + + SSL* m_native_ssl_handle; }; #endif