diff --git a/modules/ripple_asio/ripple_asio.cpp b/modules/ripple_asio/ripple_asio.cpp index 1a4ec5c73..128363f8c 100644 --- a/modules/ripple_asio/ripple_asio.cpp +++ b/modules/ripple_asio/ripple_asio.cpp @@ -26,6 +26,10 @@ namespace ripple #include "boost/ripple_IoService.cpp" #include "boost/ripple_SslContext.cpp" +# include "sockets/ripple_RippleTlsContext.h" +# include "sockets/ripple_MultiSocket.h" +#include "sockets/ripple_MultiSocketType.h" + #include "sockets/ripple_RippleTlsContext.cpp" #include "sockets/ripple_MultiSocket.cpp" diff --git a/modules/ripple_asio/ripple_asio.h b/modules/ripple_asio/ripple_asio.h index 24c595166..55918ebcb 100644 --- a/modules/ripple_asio/ripple_asio.h +++ b/modules/ripple_asio/ripple_asio.h @@ -31,10 +31,6 @@ using namespace beast; #include "boost/ripple_IoService.h" #include "boost/ripple_SslContext.h" -#include "sockets/ripple_MultiSocket.h" -# include "sockets/ripple_RippleTlsContext.h" -#include "sockets/ripple_MultiSocketType.h" - } #endif diff --git a/modules/ripple_asio/sockets/ripple_MultiSocket.cpp b/modules/ripple_asio/sockets/ripple_MultiSocket.cpp index 6f129e2ac..e8c5dbbdc 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocket.cpp +++ b/modules/ripple_asio/sockets/ripple_MultiSocket.cpp @@ -23,10 +23,16 @@ void MultiSocket::Options::setFromFlags (Flags flags) //------------------------------------------------------------------------------ -MultiSocket* MultiSocket::New (boost::asio::io_service& io_service, - Options const& options) +MultiSocket* MultiSocket::New (boost::asio::io_service& io_service, int flags) { - return new MultiSocketType (io_service, options); + return new MultiSocketType (io_service, flags); +} + +SslContextBase::BoostContextType &MultiSocket::getRippleTlsBoostContext () +{ + static ScopedPointer context (RippleTlsContext::New ()); + + return context->getBoostContext (); } //------------------------------------------------------------------------------ @@ -39,48 +45,35 @@ public: public: typedef int arg_type; - // These flags get combined to determine the multisocket attributes - // - enum Flags - { - none = 0, - client_ssl = 1, - server_ssl = 2, - server_ssl_required = 4, - server_proxy = 8, - - // these are for producing the endpoint test parameters - tcpv4 = 16, - tcpv6 = 32 - }; - - MultiSocketDetails (arg_type flags) + MultiSocketDetails (int flags) : m_flags (flags) { - m_socketOptions.useClientSsl = (flags & client_ssl) != 0; - m_socketOptions.enableServerSsl = (flags & (server_ssl | server_ssl_required)) != 0; - m_socketOptions.requireServerSsl = (flags & server_ssl_required) != 0; - m_socketOptions.requireServerProxy = (flags & server_proxy) !=0; } static String getArgName (arg_type arg) { String s; - if (arg & tcpv4) s << "tcpv4:"; - if (arg & tcpv6) s << "tcpv6:"; - if (arg != 0) + + if (arg & MultiSocket::Flag::client_role) + s << "client,"; + + if (arg & MultiSocket::Flag::server_role) + s << "server,"; + + if (arg & MultiSocket::Flag::ssl) + s << "ssl,"; + + if (arg & MultiSocket::Flag::ssl_required) + s << "ssl_required,"; + + if (arg & MultiSocket::Flag::proxy) + s << "proxy,"; + + if (s != String::empty) { - s << "["; - if (arg & client_ssl) s << "client_ssl,"; - if (arg & server_ssl) s << "server_ssl,"; - if (arg & server_ssl_required) s << "server_ssl_required,"; - if (arg & server_proxy) s << "server_proxy,"; - s = s.substring (0, s.length () - 1) + "]"; - } - else - { - s = "[plain]"; + s = "(" + s.substring (0, s.length () - 1) + ")"; } + return s; } @@ -94,23 +87,17 @@ public: return m_flags; } - MultiSocket::Options const& getSocketOptions () const noexcept - { - return m_socketOptions; - } - protected: arg_type m_flags; - MultiSocket::Options m_socketOptions; }; //-------------------------------------------------------------------------- - template + template class MultiSocketDetailsType : public MultiSocketDetails { protected: - typedef InternetProtocol protocol_type; + typedef Protocol protocol_type; typedef typename protocol_type::socket socket_type; typedef typename protocol_type::acceptor acceptor_type; typedef typename protocol_type::endpoint endpoint_type; @@ -120,11 +107,11 @@ public: typedef socket_type native_socket_type; typedef acceptor_type native_acceptor_type; - explicit MultiSocketDetailsType (arg_type flags = none) + explicit MultiSocketDetailsType (arg_type flags) : MultiSocketDetails (flags) , m_socket (get_io_service ()) , m_acceptor (get_io_service ()) - , m_multiSocket (m_socket, getSocketOptions ()) + , m_multiSocket (m_socket, flags) , m_acceptor_wrapper (m_acceptor) { } @@ -151,85 +138,90 @@ public: endpoint_type get_endpoint (PeerRole role) { - if (getFlags () & MultiSocketDetails::tcpv6) - { - if (role == PeerRole::server) - return endpoint_type (boost::asio::ip::tcp::v6 (), 1052); - else - return endpoint_type (boost::asio::ip::address_v6 ().from_string ("::1"), 1052); - } + if (role == PeerRole::server) + return endpoint_type (boost::asio::ip::tcp::v6 (), 1052); else - { - if (role == PeerRole::server) - return endpoint_type (boost::asio::ip::address_v4::any (), 1053); - else - return endpoint_type (boost::asio::ip::address_v4::loopback (), 1053); - } + return endpoint_type (boost::asio::ip::address_v6 ().from_string ("::1"), 1052); } protected: socket_type m_socket; acceptor_type m_acceptor; MultiSocketType m_multiSocket; - SocketWrapper m_acceptor_wrapper; + SocketWrapper m_acceptor_wrapper; }; - MultiSocketTests () : UnitTest ("MultiSocket", "ripple", runManual) + //-------------------------------------------------------------------------- + + template + void run_async (ClientArg const& clientArg, ServerArg const& serverArg) { + PeerTest::run , + TestPeerLogicAsyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); + } + + template + void run (ClientArg const& clientArg, ServerArg const& serverArg) + { + PeerTest::run , + TestPeerLogicSyncClient, TestPeerLogicSyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); + + PeerTest::run , + TestPeerLogicAsyncClient, TestPeerLogicSyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); + + PeerTest::run , + TestPeerLogicSyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); + + run_async (clientArg, serverArg); + } + + //-------------------------------------------------------------------------- + + template + void testFlags (int extraClientFlags, int extraServerFlags) + { + check_precondition (! MultiSocket::Flag (extraClientFlags).any_set (MultiSocket::Flag::client_role | MultiSocket::Flag::server_role)); + +#if 1 + run (MultiSocket::Flag::client_role | extraClientFlags, MultiSocket::Flag::server_role | extraServerFlags); +#else + run_async (MultiSocket::Flag::client_role | extraClientFlags, MultiSocket::Flag::server_role | extraServerFlags); +#endif + } + + template + void testProtocol () + { +#if 1 + // These should pass. + run (0, 0); + run (MultiSocket::Flag::client_role, 0); + run (0, MultiSocket::Flag::server_role); + run (MultiSocket::Flag::client_role, MultiSocket::Flag::server_role); +#endif + +#if 0 + // These should pass + testFlags (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl_required); + testFlags (0, MultiSocket::Flag::ssl); + testFlags (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl); +#endif + } + + void runTest () + { + testProtocol (); } //-------------------------------------------------------------------------- enum { - timeoutSeconds = 1 + timeoutSeconds = 10 }; - template - PeerTest::Results runProtocol (Arg const& arg) + MultiSocketTests () : UnitTest ("MultiSocket", "ripple", runManual) { - return PeerTest::run , - TestPeerLogicAsyncServer, TestPeerLogicAsyncClient> (arg, timeoutSeconds); - } - - // Analyzes the results of the test based on the flags - // - void reportResults (int flags, PeerTest::Results const& results) - { - if ( (flags & MultiSocketDetails::client_ssl) != 0) - { - if ( ((flags & MultiSocketDetails::server_ssl) == 0) && - ((flags & MultiSocketDetails::server_ssl_required) == 0)) - { - } - } - else - { - } - - results.report (*this); - } - - void testOptions (int flags) - { - PeerTest::Results const results = runProtocol (flags); - - results.report (*this); - } - - //-------------------------------------------------------------------------- - - void runTest () - { - // These should pass - testOptions (MultiSocketDetails::none); - testOptions (MultiSocketDetails::server_ssl); - testOptions (MultiSocketDetails::client_ssl | MultiSocketDetails::server_ssl); - testOptions (MultiSocketDetails::client_ssl | MultiSocketDetails::server_ssl_required); - - // These should fail - testOptions (MultiSocketDetails::client_ssl); - testOptions (MultiSocketDetails::server_ssl_required); } }; diff --git a/modules/ripple_asio/sockets/ripple_MultiSocket.h b/modules/ripple_asio/sockets/ripple_MultiSocket.h index 4e3634e37..7991cb9f3 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocket.h +++ b/modules/ripple_asio/sockets/ripple_MultiSocket.h @@ -12,6 +12,59 @@ class MultiSocket : public Socket { public: + // immutable flags + struct Flag + { + enum Bit + { + peer = 0, // no handshaking. remaining flags ignored. + client_role = 1, // operate in client role + server_role = 2, // operate in server role + + proxy = 4, // client: will send PROXY handshake + // server: PROXY handshake required + + ssl = 8, // client: will use ssl + // server: will allow, but not require ssl + + ssl_required = 16 // client: ignored + // server: will require ssl (ignores ssl flag) + }; + + Flag (int flags) noexcept + : m_flags (flags) + { + } + + bool operator== (Flag const& other) const noexcept + { + return m_flags == other.m_flags; + } + + inline bool set (int mask) const noexcept + { + return (m_flags & mask) == mask; + } + + inline bool any_set (int mask) const noexcept + { + return (m_flags & mask) != 0; + } + + Flag with (int mask) const noexcept + { + return Flag (m_flags | mask); + } + + Flag without (int mask) const noexcept + { + return Flag (m_flags & ~mask); + } + + private: + int m_flags; + }; + enum Flags { none = 0, @@ -42,8 +95,12 @@ public: void setFromFlags (Flags flags); }; - static MultiSocket* New (boost::asio::io_service& io_service, - Options const& options = none); + static MultiSocket* New (boost::asio::io_service& io_service, int flags = 0); + + // Ripple uses a SSL/TLS context with specific parameters and this returns + // a reference to the corresponding boost::asio::ssl::context object. + // + static SslContextBase::BoostContextType &getRippleTlsBoostContext (); }; #endif diff --git a/modules/ripple_asio/sockets/ripple_MultiSocketType.h b/modules/ripple_asio/sockets/ripple_MultiSocketType.h index 2f37bec6f..12347d253 100644 --- a/modules/ripple_asio/sockets/ripple_MultiSocketType.h +++ b/modules/ripple_asio/sockets/ripple_MultiSocketType.h @@ -7,11 +7,103 @@ #ifndef RIPPLE_MULTISOCKETTYPE_H_INCLUDED #define RIPPLE_MULTISOCKETTYPE_H_INCLUDED -/** Template for producing instances of MultiSocket +//------------------------------------------------------------------------------ - StreamSocket must meet these requirements: +template +class ssl_detector + : public boost::asio::ssl::stream_base + , public boost::asio::socket_base +{ +private: + typedef typename boost::remove_reference ::type stream_type; + +public: +#if 0 + struct SocketInterfaces + : SocketInterface::Handshake + { }; +#endif + + template + ssl_detector (Arg& arg) + : m_next_layer (arg) + , m_stream (m_next_layer) + { + } + + // basic_io_object + + boost::asio::io_service& get_io_service () + { + return m_next_layer.get_io_service (); + } - SocketInterface::Socket, SocketInterface::Stream + // basic_socket + + typedef typename stream_type::protocol_type protocol_type; + typedef typename stream_type::lowest_layer_type lowest_layer_type; + + lowest_layer_type& lowest_layer () + { + return m_next_layer.lowest_layer (); + } + + lowest_layer_type const& lowest_layer () const + { + return m_next_layer.lowest_layer (); + } + + // ssl::stream + + boost::system::error_code handshake (handshake_type type, boost::system::error_code& ec) + { + ec = boost::system::error_code (); + + m_buffer.commit (boost::asio::read (m_next_layer, + m_buffer.prepare (m_detector.needed ()), ec)); + + if (!ec) + { + bool const finished = m_detector.analyze (m_buffer.data ()); + + if (finished) + { + } + } + + return ec; + } + + template + BOOST_ASIO_INITFN_RESULT_TYPE(HandshakeHandler, void (boost::system::error_code)) + async_handshake(handshake_type type, BOOST_ASIO_MOVE_ARG(HandshakeHandler) handler) + { + } + +#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE + template + void handshake (handshake_type type, const ConstBufferSequence& buffers) + { + } + + template + BOOST_ASIO_INITFN_RESULT_TYPE(BufferedHandshakeHandler, void (boost::system::error_code, std::size_t)) + async_handshake(handshake_type type, const ConstBufferSequence& buffers, + BOOST_ASIO_MOVE_ARG(BufferedHandshakeHandler) handler) + { + } +#endif + +private: + Stream m_next_layer; + boost::asio::streambuf m_buffer; + boost::asio::buffered_read_stream m_stream; + HandshakeDetectorType m_detector; +}; + +//------------------------------------------------------------------------------ + +/** Template for producing instances of MultiSocket */ template class MultiSocketType @@ -24,32 +116,19 @@ public: typedef typename next_layer_type::lowest_layer_type lowest_layer_type; typedef typename boost::asio::ssl::stream ssl_stream_type; - typedef MultiSocketType ThisType; - - enum Status - { - needMore, - proxy, - plain, - ssl - }; - template - explicit MultiSocketType (Arg& arg, Options options = Options()) + explicit MultiSocketType (Arg& arg, int flags = 0, Options const& options = Options ()) : m_options (options) - , m_context (RippleTlsContext::New ()) + , m_flags (cleaned_flags (flags)) , m_next_layer (arg) - , m_io_service (m_next_layer.get_io_service ()) - , m_strand (m_io_service) - , m_status (needMore) - , m_role (Socket::client) - + , m_stream (new_current_stream (true)) + , m_strand (get_io_service ()) { } Socket& stream () const noexcept { - fatal_assert (m_stream != nullptr); + bassert (m_stream != nullptr); return *m_stream; } @@ -75,77 +154,34 @@ public: //-------------------------------------------------------------------------- // - // General + // basic_io_object // - //-------------------------------------------------------------------------- boost::asio::io_service& get_io_service () { - return lowest_layer ().get_io_service (); - } - - bool requires_handshake () - { - return true; - } - - void* this_layer_raw (char const* type_name) const - { - // 'this' layer will be the underlying StreamSocket since - // we support all of its functionality. -#if 1 - char const* const next_layer_type_name (typeid (next_layer_type).name ()); - if (strcmp (type_name, next_layer_type_name) == 0) - return const_cast (static_cast (&next_layer ())); - return nullptr; -#else - if (m_stream != nullptr) - { - if (strcmp (type_name, typeid (lowest_layer_type).name ()) == 0) - return const_cast (static_cast (&lowest_layer ())); - } - pure_virtual (); - return nullptr; -#endif + return m_next_layer.get_io_service (); } //-------------------------------------------------------------------------- // - // SocketInterface::Close + // basic_socket // - //-------------------------------------------------------------------------- - boost::system::error_code close (boost::system::error_code& ec) + void* lowest_layer (char const* type_name) const { - return lowest_layer ().close (ec); - } - - //-------------------------------------------------------------------------- - // - // SocketInterface::Acceptor - // - //-------------------------------------------------------------------------- - - // nothing yet - - //-------------------------------------------------------------------------- - // - // SocketInterface::LowestLayer - // - //-------------------------------------------------------------------------- - - void* lowest_layer_raw (char const* type_name) const - { - if (strcmp (type_name, typeid (lowest_layer_type).name ()) == 0) + char const* const name (typeid (lowest_layer_type).name ()); + if (strcmp (name, type_name) == 0) return const_cast (static_cast (&lowest_layer ())); return nullptr; } - //-------------------------------------------------------------------------- - // - // SocketInterface::Socket - // - //-------------------------------------------------------------------------- + void* native_handle (char const* type_name) const + { + char const* const name (typeid (next_layer_type).name ()); + if (strcmp (name, type_name) == 0) + return const_cast (static_cast (&next_layer ())); + return nullptr; + } boost::system::error_code cancel (boost::system::error_code& ec) { @@ -157,21 +193,18 @@ public: return lowest_layer ().shutdown (what, ec); } + boost::system::error_code close (boost::system::error_code& ec) + { + return lowest_layer ().close (ec); + } + //-------------------------------------------------------------------------- // - // SocketInterface::Stream + // basic_stream_socket // - //-------------------------------------------------------------------------- std::size_t read_some (MutableBuffers const& buffers, boost::system::error_code& ec) { - if (m_buffer.size () > 0) - { - ec = boost::system::error_code (); - std::size_t const amount = boost::asio::buffer_copy (buffers, m_buffer.data ()); - m_buffer.consume (amount); - return amount; - } return stream ().read_some (buffers, ec); } @@ -183,65 +216,33 @@ public: BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(TransferCall, void (boost::system::error_code, std::size_t)) async_read_some (MutableBuffers const& buffers, BOOST_ASIO_MOVE_ARG(TransferCall) handler) { - if (m_buffer.size () > 0) - { - // Return the leftover bytes from the handshake - std::size_t const amount = boost::asio::buffer_copy (buffers, m_buffer.data ()); - m_buffer.consume (amount); - return m_io_service.post (m_strand.wrap (boost::bind ( - BOOST_ASIO_MOVE_CAST(TransferCall)(handler), boost::system::error_code (), amount))); - } return stream ().async_read_some (buffers, m_strand.wrap (boost::bind ( BOOST_ASIO_MOVE_CAST(TransferCall)(handler), boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred))); + boost::asio::placeholders::bytes_transferred))); } BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(TransferCall, void (boost::system::error_code, std::size_t)) async_write_some (ConstBuffers const& buffers, BOOST_ASIO_MOVE_ARG(TransferCall) handler) { - return stream ().async_write_some (buffers, - BOOST_ASIO_MOVE_CAST(TransferCall)(handler)); + return stream ().async_write_some (buffers, m_strand.wrap (boost::bind ( + BOOST_ASIO_MOVE_CAST(TransferCall)(handler), boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred))); } //-------------------------------------------------------------------------- // - // SocketInterface::Handshake + // ssl::stream // - //-------------------------------------------------------------------------- - boost::system::error_code handshake (Socket::handshake_type role, boost::system::error_code& ec) + bool requires_handshake () { - Action action = calcAction (role); + return stream ().requires_handshake (); + } - switch (action) - { - default: - case actionPlain: - handshakePlain (ec); - break; - - case actionSsl: - handshakeSsl (ec); - break; - - case actionDetect: - detectHandshake (ec); - if (! ec) - { - action = calcDetectAction (ec); - switch (action) - { - default: - case actionPlain: - handshakePlain (ec); - break; - case actionSsl: - handshakeSsl (ec); - break; - }; - } - break; - } + boost::system::error_code handshake (Socket::handshake_type type, boost::system::error_code& ec) + { + if (! prepare_handshake (type, ec)) + return stream ().handshake (type, ec); return ec; } @@ -249,57 +250,22 @@ public: BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(ErrorCall, void (boost::system::error_code)) async_handshake (handshake_type type, BOOST_ASIO_MOVE_ARG(ErrorCall) handler) { - Action const action = calcAction (type); - switch (action) - { - default: - case actionPlain: - return handshakePlainAsync (BOOST_ASIO_MOVE_CAST(ErrorCall)(handler)); - break; + boost::system::error_code ec; - case actionSsl: - return handshakeSslAsync (BOOST_ASIO_MOVE_CAST(ErrorCall)(handler)); - break; + if (! prepare_handshake (type, ec)) + return stream ().async_handshake (type, BOOST_ASIO_MOVE_CAST(ErrorCall)(handler)); - case actionDetect: - return detectHandshakeAsync (BOOST_ASIO_MOVE_CAST(ErrorCall)(handler)); - break; - } + return get_io_service ().post (boost::bind ( + BOOST_ASIO_MOVE_CAST(ErrorCall)(handler), ec)); } #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE boost::system::error_code handshake (handshake_type type, ConstBuffers const& buffers, boost::system::error_code& ec) { - Action action = calcAction (type); - ec = boost::system::error_code (); - switch (action) - { - default: - case actionPlain: - handshakePlain (buffers, ec); - break; - case actionSsl: - handshakeSsl (buffers, ec); - break; - case actionDetect: - detectHandshake (buffers, ec); - if (! ec) - { - action = calcDetectAction (ec); - switch (action) - { - default: - case actionPlain: - handshakePlain (buffers, ec); - break; - case actionSsl: - handshakeSsl (buffers, ec); - break; - }; - } - break; - } + if (! prepare_handshake (type, ec)) + return stream ().handshake (type, buffers, ec); + return ec; } @@ -307,402 +273,202 @@ public: async_handshake (handshake_type type, ConstBuffers const& buffers, BOOST_ASIO_MOVE_ARG(TransferCall) handler) { - Action const action = calcAction (type); - switch (action) - { - default: - case actionPlain: - return handshakePlainAsync (buffers, - BOOST_ASIO_MOVE_CAST(TransferCall)(handler)); - break; + boost::system::error_code ec; - case actionSsl: - return handshakeSslAsync (buffers, + if (! prepare_handshake (type, ec)) + return stream ().async_handshake (type, buffers, BOOST_ASIO_MOVE_CAST(TransferCall)(handler)); - break; - case actionDetect: - return detectHandshakeAsync (buffers, - BOOST_ASIO_MOVE_CAST(TransferCall)(handler)); - break; - } + return get_io_service ().post (boost::bind ( + BOOST_ASIO_MOVE_CAST(TransferCall)(handler), ec, 0)); } #endif boost::system::error_code shutdown (boost::system::error_code& ec) { - if (m_status == ssl) - { - return m_ssl_stream->shutdown (ec); - } - else - { - // we need to close the lwest layer - return m_next_layer.shutdown (next_layer_type::shutdown_both, ec); - } + if (stream ().requires_handshake ()) + return stream ().shutdown (ec); + + return handshake_error (ec); } BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(ErrorCall, void (boost::system::error_code)) async_shutdown (BOOST_ASIO_MOVE_ARG(ErrorCall) handler) { - if (m_status == ssl) + if (stream ().requires_handshake ()) + return stream ().async_shutdown (BOOST_ASIO_MOVE_CAST(ErrorCall)(handler)); + + boost::system::error_code ec; + handshake_error (ec); + return get_io_service ().post (boost::bind ( + BOOST_ASIO_MOVE_CAST(ErrorCall)(handler), ec)); + } + + //-------------------------------------------------------------------------- + // + // Handshake implementation details + // + //-------------------------------------------------------------------------- + + // Checks flags for preconditions and returns a cleaned-up version + // + static Flag cleaned_flags (Flag flags) + { + // Can't set both client and server + check_precondition (! flags.set (Flag::client_role | Flag::server_role)); + + if (flags.set (Flag::client_role)) { - m_ssl_stream->async_shutdown (m_strand.wrap (boost::bind ( - BOOST_ASIO_MOVE_CAST(ErrorCall)(handler), - boost::asio::placeholders::error))); + // clients ignore ssl_required + flags = flags.without (Flag::ssl_required); + } + else if (flags.set (Flag::server_role)) + { + // servers ignore ssl when ssl_required is set + if (flags.set (Flag::ssl_required)) + flags = flags.without (Flag::ssl); } else { - boost::system::error_code ec; - m_next_layer.shutdown (next_layer_type::shutdown_both, ec); - m_io_service.post (m_strand.wrap (boost::bind ( - BOOST_ASIO_MOVE_CAST(ErrorCall)(handler), ec))); + // if not client or server, clear out all the flags + flags = Flag::peer; } + + return flags; } //-------------------------------------------------------------------------- - enum Action + Socket* new_plain_stream () { - actionDetect, - actionPlain, - actionSsl, - actionFail - }; - - // Determines what action to take based on - // the stream options and the desired role. - // - Action calcAction (Socket::handshake_type role) - { - m_role = role; - - if (role == Socket::server) - { - if (! m_options.enableServerSsl && - ! m_options.requireServerSsl && - ! m_options.requireServerProxy) - { - return actionPlain; - } - else if (m_options.requireServerSsl && ! m_options.requireServerProxy) - { - return actionSsl; - } - else - { - return actionDetect; - } - } - else if (m_role == Socket::client) - { - if (m_options.useClientSsl) - { - return actionSsl; - } - else - { - return actionPlain; - } - } - - return actionPlain; + return new SocketWrapper (m_next_layer); } - // Determines what action to take based on the auto-detected - // handshake, the stream options, and desired role. - // - Action calcDetectAction (boost::system::error_code& ec) + Socket* new_ssl_stream () { - ec = boost::system::error_code (); + typedef typename boost::asio::ssl::stream ssl_stream_type; - if (m_status == plain) + return new SocketWrapper ( + m_next_layer, MultiSocket::getRippleTlsBoostContext ()); + } + + Socket* new_detect_ssl_stream () + { + return new SocketWrapper > (m_next_layer); + //return nullptr; + } + + Socket* new_proxy_stream () + { + return nullptr; + } + + // Creates a Socket representing what the current stream should + // be based on the flags and the state of the handshaking engine. + // + Socket* new_current_stream (bool createPlain) + { + Socket* result = nullptr; + + if (m_flags.set (Flag::client_role)) { - if (! m_options.requireServerProxy && ! m_options.requireServerSsl) + if (m_flags.set (Flag::proxy)) { - return actionPlain; + // unsupported function + SocketBase::pure_virtual (); + } + + if (m_flags.set (Flag::ssl)) + { + result = new_ssl_stream (); } else { - failedHandshake (ec); - return actionFail; + result = nullptr; } } - else if (m_status == ssl) + else { - if (! m_options.requireServerProxy) + if (m_flags.set (Flag::proxy)) { - if (m_options.enableServerSsl || m_options.requireServerSsl) - { - return actionSsl; - } - else - { - failedHandshake (ec); - return actionFail; - } + result = new_proxy_stream (); + } + else if (m_flags.set (Flag::ssl_required)) + { + result = new_ssl_stream (); + } + else if (m_flags.set (Flag::ssl)) + { + result = new_detect_ssl_stream (); } else { - failedHandshake (ec); - return actionFail; - } - } - else if (m_status == proxy) - { - if (m_options.requireServerProxy) - { - // read the rest of the proxy string - // then transition to SSL handshake mode - failedHandshake (ec); - return actionFail; - } - else - { - // Can we make PROXY optional? - failedHandshake (ec); - return actionFail; + result = nullptr; } } - failedHandshake (ec); - return actionFail; + if (createPlain) + { + if (result == nullptr) + result = new_plain_stream (); + } + + return result; } //-------------------------------------------------------------------------- - // called when options disallow handshake - void failedHandshake (boost::system::error_code& ec) + // Bottleneck to indicate a failed handshake. + // + boost::system::error_code handshake_error (boost::system::error_code& ec) { // VFALCO TODO maybe use a ripple error category? // set this to something custom that we can recognize later? - ec = boost::asio::error::invalid_argument; + // + return ec = boost::asio::error::invalid_argument; } - void createPlainStreamSocket () - { - m_status = plain; - m_stream = new SocketWrapper (m_next_layer); - } - - void handshakePlain (boost::system::error_code& ec) + // Sanity checks and cleans up our flags based on the desired handshake + // type and correctly sets the current stream based on the settings. + // If the requested handshake type is incompatible with the socket flags + // set on construction, an error is returned. + // + boost::system::error_code prepare_handshake (handshake_type type, + boost::system::error_code& ec) { ec = boost::system::error_code (); - createPlainStreamSocket (); - } - void handshakePlain (ConstBuffers const& buffers, boost::system::error_code& ec) - { - fatal_assert (boost::asio::buffer_size (buffers) == 0 ); - ec = boost::system::error_code (); - createPlainStreamSocket (); - } + // If we constructed as 'peer' then become a client + // or a server as needed according to the handshake type. + // + if (m_flags == Flag (Flag::peer)) + m_flags = (type == server) ? + Flag (Flag::server_role) : Flag (Flag::client_role); - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(ErrorCall, void (boost::system::error_code)) - handshakePlainAsync (BOOST_ASIO_MOVE_ARG(ErrorCall) handler) - { - createPlainStreamSocket (); - return m_io_service.post (m_strand.wrap (boost::bind ( - BOOST_ASIO_MOVE_CAST(ErrorCall)(handler), boost::system::error_code()))); - } + // If we constructed with a specific role, the handshake + // must match the role. + // + if ( (type == client && ! m_flags.set (Flag::client_role)) || + (type == server && ! m_flags.set (Flag::server_role))) + return handshake_error (ec); - void createSslStreamSocket () - { - m_status = ssl; - m_ssl_stream = new ssl_stream_type (m_next_layer, m_context->getBoostContext ()); - m_stream = new SocketWrapper (*m_ssl_stream); - } + // Figure out what type of new stream we need. + // We pass false to avoid creating a redundant plain stream + Socket* socket = new_current_stream (false); - void handshakeSsl (boost::system::error_code& ec) - { - createSslStreamSocket (); - m_ssl_stream->handshake (m_role, ec); - } + if (socket != nullptr) + m_stream = socket; - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(ErrorCall, void (boost::system::error_code)) - handshakeSslAsync (BOOST_ASIO_MOVE_ARG(ErrorCall) handler) - { - createSslStreamSocket (); - return m_ssl_stream->async_handshake (m_role, - BOOST_ASIO_MOVE_CAST(ErrorCall)(handler)); - } - -#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(TransferCall, void (boost::system::error_code, std::size_t)) - handshakePlainAsync (ConstBuffers const& buffers, - BOOST_ASIO_MOVE_ARG (TransferCall) handler) - { - fatal_assert (boost::asio::buffer_size (buffers) == 0); - createPlainStreamSocket (); - return m_io_service.post (m_strand.wrap (boost::bind ( - BOOST_ASIO_MOVE_CAST(TransferCall)(handler), - boost::system::error_code(), 0))); - } - - void handshakeSsl (ConstBuffers const& buffers, boost::system::error_code& ec) - { - createSslStreamSocket (); - m_ssl_stream->handshake (m_role, buffers, ec); - } - - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(TransferCall, void (boost::system::error_code, std::size_t)) - handshakeSslAsync (ConstBuffers const& buffers, BOOST_ASIO_MOVE_ARG(TransferCall) handler) - { - createSslStreamSocket (); - return m_ssl_stream->async_handshake (m_role, buffers, - BOOST_ASIO_MOVE_CAST(TransferCall)(handler)); - } -#endif - - //-------------------------------------------------------------------------- - - enum - { - autoDetectBytes = 5 - }; - - void detectHandshake (boost::system::error_code& ec) - { - // Top up our buffer - bassert (m_buffer.size () == 0); - std::size_t const needed = autoDetectBytes; - std::size_t const amount = m_next_layer.receive ( - m_buffer.prepare (needed), boost::asio::socket_base::message_peek, ec); - m_buffer.commit (amount); - if (! ec) - { - analyzeHandshake (m_buffer.data ()); - m_buffer.consume (amount); - if (m_status == needMore) - ec = boost::asio::error::invalid_argument; // should never happen - } - } - -#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE - void detectHandshake (ConstBuffers const& buffers, boost::system::error_code& ec) - { - m_buffer.commit (boost::asio::buffer_copy ( - m_buffer.prepare (boost::asio::buffer_size (buffers)), buffers)); - detectHandshake (ec); - } -#endif - - //-------------------------------------------------------------------------- - - void onDetectRead (ErrorCall handler, - boost::system::error_code const& ec, std::size_t bytes_transferred) - { - m_buffer.commit (bytes_transferred); - - if (! ec) - { - analyzeHandshake (m_buffer.data ()); - - boost::system::error_code ec; - - if (m_status != needMore) - { - m_buffer.consume (bytes_transferred); - - Action action = calcDetectAction (ec); - if (! ec) - { - switch (action) - { - default: - case actionPlain: - handshakePlainAsync ( - BOOST_ASIO_MOVE_CAST(ErrorCall)(handler)); - break; - case actionSsl: - handshakeSslAsync ( - BOOST_ASIO_MOVE_CAST(ErrorCall)(handler)); - break; - }; - } - } - else - { - ec = boost::asio::error::invalid_argument; - } - - if (ec) - { - m_io_service.post (m_strand.wrap (boost::bind ( - BOOST_ASIO_MOVE_CAST(ErrorCall)(handler), ec))); - } - } - } - - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(ErrorCall, void (boost::system::error_code)) - detectHandshakeAsync (BOOST_ASIO_MOVE_ARG(ErrorCall) handler) - { - bassert (m_buffer.size () == 0); - return m_next_layer.async_receive ( - m_buffer.prepare (autoDetectBytes), boost::asio::socket_base::message_peek, - m_strand.wrap (boost::bind (&ThisType::onDetectRead, this, handler, - boost::asio::placeholders::error, 0))); - } - -#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE - BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(TransferCall, void (boost::system::error_code, std::size_t)) - detectHandshakeAsync (ConstBuffers const& buffers, BOOST_ASIO_MOVE_ARG(TransferCall) handler) - { - fatal_error ("unimplemented"); - } -#endif - - //-------------------------------------------------------------------------- - - static inline bool isPrintable (unsigned char c) - { - return (c < 127) && (c > 31); - } - - template - void analyzeHandshake (ConstBufferSequence const& buffers) - { - m_status = needMore; - - unsigned char data [5]; - - std::size_t const bytes = boost::asio::buffer_copy (boost::asio::buffer (data), buffers); - - if (bytes > 0) - { - if ( isPrintable (data [0]) && - ((bytes < 2) || isPrintable (data [1])) && - ((bytes < 3) || isPrintable (data [2])) && - ((bytes < 4) || isPrintable (data [3])) && - ((bytes < 5) || isPrintable (data [4]))) - { - if (bytes < 5 || memcmp (data, "PROXY", 5) != 0) - { - m_status = plain; - } - else - { - m_status = proxy; - } - } - else - { - m_status = ssl; - } - } + return ec; } private: - Options m_options; - ScopedPointer m_context; + Options m_options; // DEPRECATED + + Flag m_flags; StreamSocket m_next_layer; - boost::asio::io_service& m_io_service; - boost::asio::io_service::strand m_strand; - Status m_status; - Socket::handshake_type m_role; ScopedPointer m_stream; - ScopedPointer m_ssl_stream; - boost::asio::streambuf m_buffer; + boost::asio::io_service::strand m_strand; }; #endif diff --git a/modules/ripple_websocket/autosocket/ripple_AutoSocket.h b/modules/ripple_websocket/autosocket/ripple_AutoSocket.h index ada35731a..200a9df99 100644 --- a/modules/ripple_websocket/autosocket/ripple_AutoSocket.h +++ b/modules/ripple_websocket/autosocket/ripple_AutoSocket.h @@ -24,10 +24,12 @@ public: typedef boost::function callback; public: +#if 0 struct SocketInterfaces : beast::SocketInterface::AsyncStream , beast::SocketInterface::AsyncHandshake , beast::SocketInterface::LowestLayer { }; +#endif AutoSocket (boost::asio::io_service& s, boost::asio::ssl::context& c) : mSecure (false)