Update MultiSocket for improved Socket

This commit is contained in:
Vinnie Falco
2013-08-24 18:00:08 -07:00
parent db4a1dfaa4
commit 7b9a5e8753
3 changed files with 211 additions and 85 deletions

View File

@@ -9,6 +9,20 @@ MultiSocket* MultiSocket::New (boost::asio::io_service& io_service, int flags)
return new MultiSocketType <boost::asio::ip::tcp::socket> (io_service, flags); return new MultiSocketType <boost::asio::ip::tcp::socket> (io_service, flags);
} }
MultiSocket* MultiSocket::New (boost::asio::ip::tcp::socket& socket, int flags)
{
return new MultiSocketType <boost::asio::ip::tcp::socket&> (socket, flags);
}
MultiSocket* MultiSocket::New (
boost::asio::ssl::stream <
boost::asio::ip::tcp::socket&>& stream, int flags)
{
return new MultiSocketType <
boost::asio::ssl::stream <
boost::asio::ip::tcp::socket&>& > (stream, flags);
}
struct RippleTlsContextHolder struct RippleTlsContextHolder
{ {
RippleTlsContextHolder () RippleTlsContextHolder ()
@@ -181,7 +195,6 @@ public:
TestPeerLogicSyncClient, TestPeerLogicAsyncServer> TestPeerLogicSyncClient, TestPeerLogicAsyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this); (clientArg, serverArg, timeoutSeconds).report (*this);
PeerTest::run <MultiSocketDetailsType <Protocol>, PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicAsyncClient, TestPeerLogicAsyncServer> TestPeerLogicAsyncClient, TestPeerLogicAsyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this); (clientArg, serverArg, timeoutSeconds).report (*this);
@@ -226,7 +239,6 @@ public:
testFlags <Protocol> (MultiSocket::Flag::ssl, testFlags <Protocol> (MultiSocket::Flag::ssl,
MultiSocket::Flag::ssl_required); MultiSocket::Flag::ssl_required);
// SSL-Detect tests // SSL-Detect tests
testFlags <Protocol> (0, testFlags <Protocol> (0,
MultiSocket::Flag::ssl); MultiSocket::Flag::ssl);

View File

@@ -64,7 +64,7 @@ public:
private: private:
int m_flags; int m_flags;
}; };
enum Flags enum Flags
{ {
none = 0, none = 0,
@@ -74,7 +74,19 @@ public:
server_proxy = 8 server_proxy = 8
}; };
virtual SSL* native_handle () = 0;
/*
// This would be the underlying StreamSocket template parameter
typedef boost::asio::ip::tcp::socket NativeSocketType;
typedef boost::asio::ssl::stream <NativeSocketType&> SslStreamType;
virtual boost::asio::ip::tcp::socket& getNativeSocket () = 0;
virtual SslStreamType& getSslStream () = 0;
*/
static MultiSocket* New (boost::asio::io_service& io_service, int flags = 0); static MultiSocket* New (boost::asio::io_service& io_service, int flags = 0);
static MultiSocket* New (boost::asio::ip::tcp::socket& socket, int flags = 0);
static MultiSocket* New (boost::asio::ssl::stream <boost::asio::ip::tcp::socket&>& stream, int flags = 0);
// Ripple uses a SSL/TLS context with specific parameters and this returns // Ripple uses a SSL/TLS context with specific parameters and this returns
// a reference to the corresponding boost::asio::ssl::context object. // a reference to the corresponding boost::asio::ssl::context object.

View File

@@ -7,6 +7,8 @@
#ifndef RIPPLE_MULTISOCKETTYPE_H_INCLUDED #ifndef RIPPLE_MULTISOCKETTYPE_H_INCLUDED
#define RIPPLE_MULTISOCKETTYPE_H_INCLUDED #define RIPPLE_MULTISOCKETTYPE_H_INCLUDED
#define MULTISOCKET_USE_STRAND 1
/** Template for producing instances of MultiSocket /** Template for producing instances of MultiSocket
*/ */
template <class StreamSocket> template <class StreamSocket>
@@ -20,16 +22,16 @@ public:
typedef typename boost::remove_reference <StreamSocket>::type next_layer_type; typedef typename boost::remove_reference <StreamSocket>::type next_layer_type;
typedef typename boost::add_reference <next_layer_type>::type next_layer_type_ref; typedef typename boost::add_reference <next_layer_type>::type next_layer_type_ref;
typedef typename next_layer_type::lowest_layer_type lowest_layer_type; typedef typename next_layer_type::lowest_layer_type lowest_layer_type;
typedef typename boost::asio::ssl::stream <next_layer_type_ref> ssl_stream_type;
template <class Arg> template <class Arg>
explicit MultiSocketType (Arg& arg, int flags = 0) explicit MultiSocketType (Arg& arg, int flags = 0)
: m_flags (flags) : m_flags (flags)
, m_state (stateNone) , m_state (stateNone)
, m_verify_mode (0)
, m_stream (nullptr) , m_stream (nullptr)
, m_needsShutdown (false) , m_needsShutdown (false)
, m_next_layer (arg) , m_next_layer (arg)
, m_strand (get_io_service ()) , m_native_ssl_handle (nullptr)
{ {
// See if our flags allow us to go directly // See if our flags allow us to go directly
// into the ready state with an active stream. // into the ready state with an active stream.
@@ -50,6 +52,20 @@ protected:
}; };
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
//
// MultiSocket
//
SSL* native_handle ()
{
bassert (m_native_ssl_handle != nullptr);
return m_native_ssl_handle;
}
//--------------------------------------------------------------------------
//
// MultiSocketType
//
/** The current stream we are passing everything through. /** The current stream we are passing everything through.
This object gets dynamically created and replaced with other This object gets dynamically created and replaced with other
@@ -61,6 +77,7 @@ protected:
return *m_stream; return *m_stream;
} }
#if 0
next_layer_type& next_layer () noexcept next_layer_type& next_layer () noexcept
{ {
return m_next_layer; return m_next_layer;
@@ -80,6 +97,30 @@ protected:
{ {
return m_next_layer.lowest_layer (); return m_next_layer.lowest_layer ();
} }
#endif
//--------------------------------------------------------------------------
//
// Socket
//
// We will return the underlying StreamSocket
void* this_layer_ptr (char const* type_name) const
{
char const* const name (typeid (next_layer_type).name ());
if (strcmp (name, type_name) == 0)
return const_cast <void*> (static_cast <void const*> (&m_next_layer));
return nullptr;
}
//--------------------------------------------------------------------------
// What would we return from here?
bool native_handle (char const*, void*) const
{
pure_virtual_called (__FILE__, __LINE__);
return false;
}
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
// //
@@ -96,35 +137,29 @@ protected:
// basic_socket // basic_socket
// //
void* lowest_layer (char const* type_name) const /** Always the lowest_layer of the underlying StreamSocket. */
void* lowest_layer_ptr (char const* type_name) const
{ {
char const* const name (typeid (lowest_layer_type).name ()); char const* const name (typeid (lowest_layer_type).name ());
if (strcmp (name, type_name) == 0) if (strcmp (name, type_name) == 0)
return const_cast <void*> (static_cast <void const*>(&lowest_layer ())); return const_cast <void*> (
return nullptr; static_cast <void const*>(&m_next_layer.lowest_layer ()));
}
void* native_handle (char const* type_name) const
{
char const* const name (typeid (next_layer_type).name ());
if (strcmp (name, type_name) == 0)
return const_cast <void*> (static_cast <void const*>(&next_layer ()));
return nullptr; return nullptr;
} }
error_code cancel (error_code& ec) error_code cancel (error_code& ec)
{ {
return lowest_layer ().cancel (ec); return m_next_layer.lowest_layer ().cancel (ec);
} }
error_code shutdown (Socket::shutdown_type what, error_code& ec) error_code shutdown (Socket::shutdown_type what, error_code& ec)
{ {
return lowest_layer ().shutdown (what, ec); return m_next_layer.lowest_layer ().shutdown (what, ec);
} }
error_code close (error_code& ec) error_code close (error_code& ec)
{ {
return lowest_layer ().close (ec); return m_next_layer.lowest_layer ().close (ec);
} }
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
@@ -144,26 +179,14 @@ protected:
void async_read_some (MutableBuffers const& buffers, SharedHandlerPtr handler) void async_read_some (MutableBuffers const& buffers, SharedHandlerPtr handler)
{ {
#if 0 stream ().async_read_some (buffers,
stream ().async_read_some (buffeers,
BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler)); BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler));
#else
stream ().async_read_some (buffers, m_strand.wrap (
boost::bind (handler, boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
#endif
} }
void async_write_some (ConstBuffers const& buffers, SharedHandlerPtr handler) void async_write_some (ConstBuffers const& buffers, SharedHandlerPtr handler)
{ {
#if 0
stream ().async_write_some (buffers, stream ().async_write_some (buffers,
BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler)); BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler));
#else
stream ().async_write_some (buffers, m_strand.wrap (
boost::bind ( handler, boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
#endif
} }
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
@@ -171,6 +194,19 @@ protected:
// ssl::stream // ssl::stream
// //
/** Retrieve the next_layer.
The next_layer_type is always the StreamSocket, regardless of what
we have put around it in terms of SSL or buffering.
*/
void* next_layer_ptr (char const* type_name) const
{
char const* const name (typeid (next_layer_type).name ());
if (strcmp (name, type_name) == 0)
return const_cast <void*> (static_cast <void const*> (&m_next_layer));
return nullptr;
}
/** Determine if the caller needs to call a handshaking function. /** Determine if the caller needs to call a handshaking function.
This is also used to determine if the handshaking shutdown() This is also used to determine if the handshaking shutdown()
has to be called. has to be called.
@@ -184,6 +220,16 @@ protected:
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
void set_verify_mode (int verify_mode)
{
if (m_ssl_stream != nullptr)
m_ssl_stream->set_verify_mode (verify_mode);
else
m_verify_mode = verify_mode;
}
//--------------------------------------------------------------------------
error_code handshake (Socket::handshake_type type, error_code& ec) error_code handshake (Socket::handshake_type type, error_code& ec)
{ {
return handshake (type, ConstBuffers (), ec); return handshake (type, ConstBuffers (), ec);
@@ -215,7 +261,7 @@ protected:
ConstBuffers const& buffers, SharedHandlerPtr const& handler) ConstBuffers const& buffers, SharedHandlerPtr const& handler)
{ {
return get_io_service ().dispatch (SharedHandlerPtr ( return get_io_service ().dispatch (SharedHandlerPtr (
new AsyncOp (this, type, buffers, handler))); new AsyncOp <next_layer_type> (*this, m_next_layer, type, buffers, handler)));
} }
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
@@ -252,7 +298,7 @@ protected:
{ {
// Our interface didn't require a shutdown but someone called // Our interface didn't require a shutdown but someone called
// it anyone so generate an error code. // it anyone so generate an error code.
bassertfalse; //bassertfalse;
ec = handshake_error (); ec = handshake_error ();
} }
@@ -285,29 +331,18 @@ protected:
if (flags.set (Flag::ssl_required)) if (flags.set (Flag::ssl_required))
flags = flags.without (Flag::ssl); flags = flags.without (Flag::ssl);
} }
else
{
// if not client or server, clear out all the flags
flags = Flag::peer;
}
return flags; return flags;
} }
bool is_server () const noexcept
{
return m_flags.set (Flag::server_role);
}
bool is_client () const noexcept bool is_client () const noexcept
{ {
return m_flags.set (Flag::client_role); return m_flags.set (Flag::client_role);
} }
bool is_peer () const noexcept bool is_server () const noexcept
{ {
bassert (is_client () || is_server () || (m_flags == Flag (Flag::peer))); return m_flags.set (Flag::server_role);
return m_flags == Flag (Flag::peer);
} }
// Bottleneck to indicate a failed handshake. // Bottleneck to indicate a failed handshake.
@@ -399,8 +434,23 @@ protected:
} }
else else
{ {
m_state = stateReady; // We will determine client/server mode
m_stream = new_plain_stream (); // at the time handshake is called.
bassert (! m_flags.set (Flag::proxy));
m_flags = m_flags.without (Flag::proxy);
if (m_flags.any_set (Flag::ssl | Flag::ssl_required))
{
// We will decide stream type at handshake time
m_state = stateHandshake;
m_stream = nullptr;
}
else
{
m_state = stateReady;
m_stream = new_plain_stream ();
}
} }
// We only set this to true in stateHandshake and // We only set this to true in stateHandshake and
@@ -441,9 +491,26 @@ protected:
return handshake_error (); return handshake_error ();
} }
// Peer roles cannot handshake. // Set flags based on handshake if necessary
if (m_flags == Flag (Flag::peer)) if (! m_flags.any_set (Flag::client_role | Flag::server_role))
return handshake_error (); {
switch (type)
{
case client:
m_flags = m_flags.with (Flag::client_role);
break;
case server:
m_flags = m_flags.with (Flag::server_role);
break;
default:
fatal_error ("invalid tyoe");
break;
}
m_flags = cleaned_flags (m_flags);
}
// Handshake type must match the role flags // Handshake type must match the role flags
if ( (type == client && ! is_client ()) || if ( (type == client && ! is_client ()) ||
@@ -502,11 +569,24 @@ protected:
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
template <typename Stream>
void set_ssl_stream (boost::asio::ssl::stream <Stream>& ssl_stream)
{
m_ssl_stream = new SocketWrapper <boost::asio::ssl::stream <Stream>&>
(ssl_stream);
m_ssl_stream->set_verify_mode (m_verify_mode);
m_native_ssl_handle = ssl_stream.native_handle ();
}
//--------------------------------------------------------------------------
// Create a plain stream that just wraps the next layer. // Create a plain stream that just wraps the next layer.
// //
Socket* new_plain_stream () Socket* new_plain_stream ()
{ {
return new SocketWrapper <next_layer_type&> (m_next_layer); typedef SocketWrapper <next_layer_type&> Wrapper;
Wrapper* const socket (new Wrapper (m_next_layer));
return socket;
} }
// Create a plain stream but front-load it with some bytes. // Create a plain stream but front-load it with some bytes.
@@ -517,8 +597,10 @@ protected:
{ {
if (boost::asio::buffer_size (buffers) > 0) if (boost::asio::buffer_size (buffers) > 0)
{ {
typedef PrefilledReadStream <next_layer_type&> this_type; typedef SocketWrapper <PrefilledReadStream <next_layer_type&> > Wrapper;
return new SocketWrapper <this_type> (m_next_layer, buffers); Wrapper* const socket (new Wrapper (m_next_layer));
socket->this_layer ().fill (buffers);
return socket;
} }
return new_plain_stream (); return new_plain_stream ();
} }
@@ -527,9 +609,16 @@ protected:
// //
Socket* new_ssl_stream () Socket* new_ssl_stream ()
{ {
typedef typename boost::asio::ssl::stream <next_layer_type&> this_type; typedef typename boost::asio::ssl::stream <next_layer_type&> SslStream;
return new SocketWrapper <this_type> ( #if MULTISOCKET_USE_STRAND
typedef SocketWrapperStrand <SslStream> Wrapper;
#else
typedef SocketWrapper <SslStream> Wrapper;
#endif
Wrapper* const socket = new Wrapper (
m_next_layer, MultiSocket::getRippleTlsBoostContext ()); m_next_layer, MultiSocket::getRippleTlsBoostContext ());
set_ssl_stream (socket->this_layer ());
return socket;
} }
// Creates an ssl stream, but front-load it with some bytes. // Creates an ssl stream, but front-load it with some bytes.
@@ -540,11 +629,17 @@ protected:
{ {
if (boost::asio::buffer_size (buffers) > 0) if (boost::asio::buffer_size (buffers) > 0)
{ {
typedef PrefilledReadStream <next_layer_type&> next_type; typedef boost::asio::ssl::stream <
typedef boost::asio::ssl::stream <next_type> this_type; PrefilledReadStream <next_layer_type&> > SslStream;
SocketWrapper <this_type>* const socket = new SocketWrapper <this_type> ( #if MULTISOCKET_USE_STRAND
typedef SocketWrapperStrand <SslStream> Wrapper;
#else
typedef SocketWrapper <SslStream> Wrapper;
#endif
Wrapper* const socket = new Wrapper (
m_next_layer, MultiSocket::getRippleTlsBoostContext ()); m_next_layer, MultiSocket::getRippleTlsBoostContext ());
socket->this_layer().next_layer().fill (buffers); socket->this_layer ().next_layer().fill (buffers);
set_ssl_stream (socket->this_layer ());
return socket; return socket;
} }
return new_ssl_stream (); return new_ssl_stream ();
@@ -691,16 +786,18 @@ protected:
// Composed asynchronous handshake operator // Composed asynchronous handshake operator
// //
template <typename Stream>
struct AsyncOp : ComposedAsyncOperation struct AsyncOp : ComposedAsyncOperation
{ {
typedef SharedHandlerAllocator <char> Allocator; typedef SharedHandlerAllocator <char> Allocator;
AsyncOp (MultiSocketType <StreamSocket>* socket, AsyncOp (MultiSocketType <StreamSocket>& owner, Stream& stream,
handshake_type type, ConstBuffers const& buffers, handshake_type type, ConstBuffers const& buffers,
SharedHandlerPtr const& handler) SharedHandlerPtr const& handler)
: ComposedAsyncOperation (sizeof (*this), handler) : ComposedAsyncOperation (sizeof (*this), handler)
, m_handler (handler) , m_handler (handler)
, m_socket (socket) , m_owner (owner)
, m_stream (stream)
, m_type (type) , m_type (type)
, m_buffer (std::numeric_limits <std::size_t>::max(), handler) , m_buffer (std::numeric_limits <std::size_t>::max(), handler)
, m_running (false) , m_running (false)
@@ -726,22 +823,22 @@ protected:
{ {
m_running = true; m_running = true;
error_code ec (m_socket->initHandshake (m_type)); error_code ec (m_owner.initHandshake (m_type));
if (! ec) if (! ec)
{ {
if (m_socket->m_state != stateReady) if (m_owner.m_state != stateReady)
{ {
(*this) (ec); (*this) (ec);
return; return;
} }
// Always need shutdown if handshake successful. // Always need shutdown if handshake successful.
m_socket->m_needsShutdown = true; m_owner.m_needsShutdown = true;
} }
// Call the original handler with the error code and end. // Call the original handler with the error code and end.
m_socket->get_io_service ().wrap ( m_owner.get_io_service ().wrap (
BOOST_ASIO_MOVE_CAST (SharedHandlerPtr)(m_handler)) BOOST_ASIO_MOVE_CAST (SharedHandlerPtr)(m_handler))
(ec); (ec);
} }
@@ -756,22 +853,22 @@ protected:
do do
{ {
if (m_socket->m_state == stateReady) if (m_owner.m_state == stateReady)
{ {
// Always need shutdown if handshake successful. // Always need shutdown if handshake successful.
m_socket->m_needsShutdown = true; m_owner.m_needsShutdown = true;
break; break;
} }
switch (m_socket->m_state) switch (m_owner.m_state)
{ {
case stateHandshakeFinal: case stateHandshakeFinal:
{ {
// Have to set this beforehand even // Have to set this beforehand even
// though we might get an error. // though we might get an error.
// //
m_socket->m_state = stateReady; m_owner.m_state = stateReady;
m_socket->stream ().async_handshake (m_type, m_owner.stream ().async_handshake (m_type,
SharedHandlerPtr (this)); SharedHandlerPtr (this));
} }
return; return;
@@ -782,22 +879,22 @@ protected:
{ {
if (m_proxy.getLogic ().success ()) if (m_proxy.getLogic ().success ())
{ {
m_socket->setProxyInfo (m_proxy.getLogic ().getInfo ()); m_owner.setProxyInfo (m_proxy.getLogic ().getInfo ());
// Strip off the PROXY flag. // Strip off the PROXY flag.
m_socket->m_flags = m_socket->m_flags.without (Flag::proxy); m_owner.m_flags = m_owner.m_flags.without (Flag::proxy);
// Update handshake state with the leftover bytes. // Update handshake state with the leftover bytes.
ec = m_socket->initHandshake (m_type, m_buffer.data ()); ec = m_owner.initHandshake (m_type, m_buffer.data ());
break; break;
} }
// Didn't get the PROXY handshake we needed // Didn't get the PROXY handshake we needed
ec = m_socket->handshake_error (); ec = m_owner.handshake_error ();
break; break;
} }
m_proxy.async_detect (m_socket->next_layer (), m_proxy.async_detect (m_stream,
m_buffer, SharedHandlerPtr (this)); m_buffer, SharedHandlerPtr (this));
} }
return; return;
@@ -810,21 +907,21 @@ protected:
if (m_ssl.getLogic ().success ()) if (m_ssl.getLogic ().success ())
{ {
// Convert the ssl flag to ssl_required // Convert the ssl flag to ssl_required
m_socket->m_flags = m_socket->m_flags.with ( m_owner.m_flags = m_owner.m_flags.with (
Flag::ssl_required).without (Flag::ssl); Flag::ssl_required).without (Flag::ssl);
} }
else else
{ {
// Not SSL, strip the ssl flag // Not SSL, strip the ssl flag
m_socket->m_flags = m_socket->m_flags.without (Flag::ssl); m_owner.m_flags = m_owner.m_flags.without (Flag::ssl);
} }
// Update handshake state with the leftover bytes. // Update handshake state with the leftover bytes.
ec = m_socket->initHandshake (m_type, m_buffer.data ()); ec = m_owner.initHandshake (m_type, m_buffer.data ());
break; break;
} }
m_ssl.async_detect (m_socket->next_layer (), m_ssl.async_detect (m_stream,
m_buffer, SharedHandlerPtr (this)); m_buffer, SharedHandlerPtr (this));
} }
return; return;
@@ -834,15 +931,15 @@ protected:
case stateHandshake: case stateHandshake:
default: default:
fatal_error ("invalid state"); fatal_error ("invalid state");
ec = m_socket->handshake_error (); ec = m_owner.handshake_error ();
} }
} }
while (! ec); while (! ec);
bassert (ec || (m_socket->m_state == stateReady && m_socket->m_needsShutdown)); bassert (ec || (m_owner.m_state == stateReady && m_owner.m_needsShutdown));
// Call the original handler with the error code and end. // Call the original handler with the error code and end.
m_socket->get_io_service ().wrap ( m_owner.get_io_service ().wrap (
BOOST_ASIO_MOVE_CAST (SharedHandlerPtr)(m_handler)) BOOST_ASIO_MOVE_CAST (SharedHandlerPtr)(m_handler))
(ec); (ec);
} }
@@ -854,7 +951,8 @@ protected:
private: private:
SharedHandlerPtr m_handler; SharedHandlerPtr m_handler;
MultiSocketType <StreamSocket>* m_socket; MultiSocketType <StreamSocket>& m_owner;
Stream& m_stream;
handshake_type const m_type; handshake_type const m_type;
boost::asio::basic_streambuf <Allocator> m_buffer; boost::asio::basic_streambuf <Allocator> m_buffer;
HandshakeDetectorType <next_layer_type, HandshakeDetectLogicPROXY> m_proxy; HandshakeDetectorType <next_layer_type, HandshakeDetectLogicPROXY> m_proxy;
@@ -865,10 +963,14 @@ protected:
private: private:
Flag m_flags; Flag m_flags;
State m_state; State m_state;
int m_verify_mode;
ScopedPointer <Socket> m_stream; ScopedPointer <Socket> m_stream;
ScopedPointer <Socket> m_ssl_stream; // the ssl portion of our stream if it exists
bool m_needsShutdown; bool m_needsShutdown;
StreamSocket m_next_layer; StreamSocket m_next_layer;
boost::asio::io_service::strand m_strand;
SSL* m_native_ssl_handle;
}; };
#endif #endif