Update MultiSocket for improved Socket

This commit is contained in:
Vinnie Falco
2013-08-24 18:00:08 -07:00
parent db4a1dfaa4
commit 7b9a5e8753
3 changed files with 211 additions and 85 deletions

View File

@@ -9,6 +9,20 @@ MultiSocket* MultiSocket::New (boost::asio::io_service& io_service, int flags)
return new MultiSocketType <boost::asio::ip::tcp::socket> (io_service, flags);
}
MultiSocket* MultiSocket::New (boost::asio::ip::tcp::socket& socket, int flags)
{
return new MultiSocketType <boost::asio::ip::tcp::socket&> (socket, flags);
}
MultiSocket* MultiSocket::New (
boost::asio::ssl::stream <
boost::asio::ip::tcp::socket&>& stream, int flags)
{
return new MultiSocketType <
boost::asio::ssl::stream <
boost::asio::ip::tcp::socket&>& > (stream, flags);
}
struct RippleTlsContextHolder
{
RippleTlsContextHolder ()
@@ -181,7 +195,6 @@ public:
TestPeerLogicSyncClient, TestPeerLogicAsyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this);
PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicAsyncClient, TestPeerLogicAsyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this);
@@ -226,7 +239,6 @@ public:
testFlags <Protocol> (MultiSocket::Flag::ssl,
MultiSocket::Flag::ssl_required);
// SSL-Detect tests
testFlags <Protocol> (0,
MultiSocket::Flag::ssl);

View File

@@ -64,7 +64,7 @@ public:
private:
int m_flags;
};
enum Flags
{
none = 0,
@@ -74,7 +74,19 @@ public:
server_proxy = 8
};
virtual SSL* native_handle () = 0;
/*
// This would be the underlying StreamSocket template parameter
typedef boost::asio::ip::tcp::socket NativeSocketType;
typedef boost::asio::ssl::stream <NativeSocketType&> SslStreamType;
virtual boost::asio::ip::tcp::socket& getNativeSocket () = 0;
virtual SslStreamType& getSslStream () = 0;
*/
static MultiSocket* New (boost::asio::io_service& io_service, int flags = 0);
static MultiSocket* New (boost::asio::ip::tcp::socket& socket, int flags = 0);
static MultiSocket* New (boost::asio::ssl::stream <boost::asio::ip::tcp::socket&>& stream, int flags = 0);
// Ripple uses a SSL/TLS context with specific parameters and this returns
// a reference to the corresponding boost::asio::ssl::context object.

View File

@@ -7,6 +7,8 @@
#ifndef RIPPLE_MULTISOCKETTYPE_H_INCLUDED
#define RIPPLE_MULTISOCKETTYPE_H_INCLUDED
#define MULTISOCKET_USE_STRAND 1
/** Template for producing instances of MultiSocket
*/
template <class StreamSocket>
@@ -20,16 +22,16 @@ public:
typedef typename boost::remove_reference <StreamSocket>::type next_layer_type;
typedef typename boost::add_reference <next_layer_type>::type next_layer_type_ref;
typedef typename next_layer_type::lowest_layer_type lowest_layer_type;
typedef typename boost::asio::ssl::stream <next_layer_type_ref> ssl_stream_type;
template <class Arg>
explicit MultiSocketType (Arg& arg, int flags = 0)
: m_flags (flags)
, m_state (stateNone)
, m_verify_mode (0)
, m_stream (nullptr)
, m_needsShutdown (false)
, m_next_layer (arg)
, m_strand (get_io_service ())
, m_native_ssl_handle (nullptr)
{
// See if our flags allow us to go directly
// into the ready state with an active stream.
@@ -50,6 +52,20 @@ protected:
};
//--------------------------------------------------------------------------
//
// MultiSocket
//
SSL* native_handle ()
{
bassert (m_native_ssl_handle != nullptr);
return m_native_ssl_handle;
}
//--------------------------------------------------------------------------
//
// MultiSocketType
//
/** The current stream we are passing everything through.
This object gets dynamically created and replaced with other
@@ -61,6 +77,7 @@ protected:
return *m_stream;
}
#if 0
next_layer_type& next_layer () noexcept
{
return m_next_layer;
@@ -80,6 +97,30 @@ protected:
{
return m_next_layer.lowest_layer ();
}
#endif
//--------------------------------------------------------------------------
//
// Socket
//
// We will return the underlying StreamSocket
void* this_layer_ptr (char const* type_name) const
{
char const* const name (typeid (next_layer_type).name ());
if (strcmp (name, type_name) == 0)
return const_cast <void*> (static_cast <void const*> (&m_next_layer));
return nullptr;
}
//--------------------------------------------------------------------------
// What would we return from here?
bool native_handle (char const*, void*) const
{
pure_virtual_called (__FILE__, __LINE__);
return false;
}
//--------------------------------------------------------------------------
//
@@ -96,35 +137,29 @@ protected:
// basic_socket
//
void* lowest_layer (char const* type_name) const
/** Always the lowest_layer of the underlying StreamSocket. */
void* lowest_layer_ptr (char const* type_name) const
{
char const* const name (typeid (lowest_layer_type).name ());
if (strcmp (name, type_name) == 0)
return const_cast <void*> (static_cast <void const*>(&lowest_layer ()));
return nullptr;
}
void* native_handle (char const* type_name) const
{
char const* const name (typeid (next_layer_type).name ());
if (strcmp (name, type_name) == 0)
return const_cast <void*> (static_cast <void const*>(&next_layer ()));
return const_cast <void*> (
static_cast <void const*>(&m_next_layer.lowest_layer ()));
return nullptr;
}
error_code cancel (error_code& ec)
{
return lowest_layer ().cancel (ec);
return m_next_layer.lowest_layer ().cancel (ec);
}
error_code shutdown (Socket::shutdown_type what, error_code& ec)
{
return lowest_layer ().shutdown (what, ec);
return m_next_layer.lowest_layer ().shutdown (what, ec);
}
error_code close (error_code& ec)
{
return lowest_layer ().close (ec);
return m_next_layer.lowest_layer ().close (ec);
}
//--------------------------------------------------------------------------
@@ -144,26 +179,14 @@ protected:
void async_read_some (MutableBuffers const& buffers, SharedHandlerPtr handler)
{
#if 0
stream ().async_read_some (buffeers,
stream ().async_read_some (buffers,
BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler));
#else
stream ().async_read_some (buffers, m_strand.wrap (
boost::bind (handler, boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
#endif
}
void async_write_some (ConstBuffers const& buffers, SharedHandlerPtr handler)
{
#if 0
stream ().async_write_some (buffers,
BOOST_ASIO_MOVE_CAST(SharedHandlerPtr)(handler));
#else
stream ().async_write_some (buffers, m_strand.wrap (
boost::bind ( handler, boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred)));
#endif
}
//--------------------------------------------------------------------------
@@ -171,6 +194,19 @@ protected:
// ssl::stream
//
/** Retrieve the next_layer.
The next_layer_type is always the StreamSocket, regardless of what
we have put around it in terms of SSL or buffering.
*/
void* next_layer_ptr (char const* type_name) const
{
char const* const name (typeid (next_layer_type).name ());
if (strcmp (name, type_name) == 0)
return const_cast <void*> (static_cast <void const*> (&m_next_layer));
return nullptr;
}
/** Determine if the caller needs to call a handshaking function.
This is also used to determine if the handshaking shutdown()
has to be called.
@@ -184,6 +220,16 @@ protected:
//--------------------------------------------------------------------------
void set_verify_mode (int verify_mode)
{
if (m_ssl_stream != nullptr)
m_ssl_stream->set_verify_mode (verify_mode);
else
m_verify_mode = verify_mode;
}
//--------------------------------------------------------------------------
error_code handshake (Socket::handshake_type type, error_code& ec)
{
return handshake (type, ConstBuffers (), ec);
@@ -215,7 +261,7 @@ protected:
ConstBuffers const& buffers, SharedHandlerPtr const& handler)
{
return get_io_service ().dispatch (SharedHandlerPtr (
new AsyncOp (this, type, buffers, handler)));
new AsyncOp <next_layer_type> (*this, m_next_layer, type, buffers, handler)));
}
//--------------------------------------------------------------------------
@@ -252,7 +298,7 @@ protected:
{
// Our interface didn't require a shutdown but someone called
// it anyone so generate an error code.
bassertfalse;
//bassertfalse;
ec = handshake_error ();
}
@@ -285,29 +331,18 @@ protected:
if (flags.set (Flag::ssl_required))
flags = flags.without (Flag::ssl);
}
else
{
// if not client or server, clear out all the flags
flags = Flag::peer;
}
return flags;
}
bool is_server () const noexcept
{
return m_flags.set (Flag::server_role);
}
bool is_client () const noexcept
{
return m_flags.set (Flag::client_role);
}
bool is_peer () const noexcept
bool is_server () const noexcept
{
bassert (is_client () || is_server () || (m_flags == Flag (Flag::peer)));
return m_flags == Flag (Flag::peer);
return m_flags.set (Flag::server_role);
}
// Bottleneck to indicate a failed handshake.
@@ -399,8 +434,23 @@ protected:
}
else
{
m_state = stateReady;
m_stream = new_plain_stream ();
// We will determine client/server mode
// at the time handshake is called.
bassert (! m_flags.set (Flag::proxy));
m_flags = m_flags.without (Flag::proxy);
if (m_flags.any_set (Flag::ssl | Flag::ssl_required))
{
// We will decide stream type at handshake time
m_state = stateHandshake;
m_stream = nullptr;
}
else
{
m_state = stateReady;
m_stream = new_plain_stream ();
}
}
// We only set this to true in stateHandshake and
@@ -441,9 +491,26 @@ protected:
return handshake_error ();
}
// Peer roles cannot handshake.
if (m_flags == Flag (Flag::peer))
return handshake_error ();
// Set flags based on handshake if necessary
if (! m_flags.any_set (Flag::client_role | Flag::server_role))
{
switch (type)
{
case client:
m_flags = m_flags.with (Flag::client_role);
break;
case server:
m_flags = m_flags.with (Flag::server_role);
break;
default:
fatal_error ("invalid tyoe");
break;
}
m_flags = cleaned_flags (m_flags);
}
// Handshake type must match the role flags
if ( (type == client && ! is_client ()) ||
@@ -502,11 +569,24 @@ protected:
//--------------------------------------------------------------------------
template <typename Stream>
void set_ssl_stream (boost::asio::ssl::stream <Stream>& ssl_stream)
{
m_ssl_stream = new SocketWrapper <boost::asio::ssl::stream <Stream>&>
(ssl_stream);
m_ssl_stream->set_verify_mode (m_verify_mode);
m_native_ssl_handle = ssl_stream.native_handle ();
}
//--------------------------------------------------------------------------
// Create a plain stream that just wraps the next layer.
//
Socket* new_plain_stream ()
{
return new SocketWrapper <next_layer_type&> (m_next_layer);
typedef SocketWrapper <next_layer_type&> Wrapper;
Wrapper* const socket (new Wrapper (m_next_layer));
return socket;
}
// Create a plain stream but front-load it with some bytes.
@@ -517,8 +597,10 @@ protected:
{
if (boost::asio::buffer_size (buffers) > 0)
{
typedef PrefilledReadStream <next_layer_type&> this_type;
return new SocketWrapper <this_type> (m_next_layer, buffers);
typedef SocketWrapper <PrefilledReadStream <next_layer_type&> > Wrapper;
Wrapper* const socket (new Wrapper (m_next_layer));
socket->this_layer ().fill (buffers);
return socket;
}
return new_plain_stream ();
}
@@ -527,9 +609,16 @@ protected:
//
Socket* new_ssl_stream ()
{
typedef typename boost::asio::ssl::stream <next_layer_type&> this_type;
return new SocketWrapper <this_type> (
typedef typename boost::asio::ssl::stream <next_layer_type&> SslStream;
#if MULTISOCKET_USE_STRAND
typedef SocketWrapperStrand <SslStream> Wrapper;
#else
typedef SocketWrapper <SslStream> Wrapper;
#endif
Wrapper* const socket = new Wrapper (
m_next_layer, MultiSocket::getRippleTlsBoostContext ());
set_ssl_stream (socket->this_layer ());
return socket;
}
// Creates an ssl stream, but front-load it with some bytes.
@@ -540,11 +629,17 @@ protected:
{
if (boost::asio::buffer_size (buffers) > 0)
{
typedef PrefilledReadStream <next_layer_type&> next_type;
typedef boost::asio::ssl::stream <next_type> this_type;
SocketWrapper <this_type>* const socket = new SocketWrapper <this_type> (
typedef boost::asio::ssl::stream <
PrefilledReadStream <next_layer_type&> > SslStream;
#if MULTISOCKET_USE_STRAND
typedef SocketWrapperStrand <SslStream> Wrapper;
#else
typedef SocketWrapper <SslStream> Wrapper;
#endif
Wrapper* const socket = new Wrapper (
m_next_layer, MultiSocket::getRippleTlsBoostContext ());
socket->this_layer().next_layer().fill (buffers);
socket->this_layer ().next_layer().fill (buffers);
set_ssl_stream (socket->this_layer ());
return socket;
}
return new_ssl_stream ();
@@ -691,16 +786,18 @@ protected:
// Composed asynchronous handshake operator
//
template <typename Stream>
struct AsyncOp : ComposedAsyncOperation
{
typedef SharedHandlerAllocator <char> Allocator;
AsyncOp (MultiSocketType <StreamSocket>* socket,
AsyncOp (MultiSocketType <StreamSocket>& owner, Stream& stream,
handshake_type type, ConstBuffers const& buffers,
SharedHandlerPtr const& handler)
: ComposedAsyncOperation (sizeof (*this), handler)
, m_handler (handler)
, m_socket (socket)
, m_owner (owner)
, m_stream (stream)
, m_type (type)
, m_buffer (std::numeric_limits <std::size_t>::max(), handler)
, m_running (false)
@@ -726,22 +823,22 @@ protected:
{
m_running = true;
error_code ec (m_socket->initHandshake (m_type));
error_code ec (m_owner.initHandshake (m_type));
if (! ec)
{
if (m_socket->m_state != stateReady)
if (m_owner.m_state != stateReady)
{
(*this) (ec);
return;
}
// Always need shutdown if handshake successful.
m_socket->m_needsShutdown = true;
m_owner.m_needsShutdown = true;
}
// Call the original handler with the error code and end.
m_socket->get_io_service ().wrap (
m_owner.get_io_service ().wrap (
BOOST_ASIO_MOVE_CAST (SharedHandlerPtr)(m_handler))
(ec);
}
@@ -756,22 +853,22 @@ protected:
do
{
if (m_socket->m_state == stateReady)
if (m_owner.m_state == stateReady)
{
// Always need shutdown if handshake successful.
m_socket->m_needsShutdown = true;
m_owner.m_needsShutdown = true;
break;
}
switch (m_socket->m_state)
switch (m_owner.m_state)
{
case stateHandshakeFinal:
{
// Have to set this beforehand even
// though we might get an error.
//
m_socket->m_state = stateReady;
m_socket->stream ().async_handshake (m_type,
m_owner.m_state = stateReady;
m_owner.stream ().async_handshake (m_type,
SharedHandlerPtr (this));
}
return;
@@ -782,22 +879,22 @@ protected:
{
if (m_proxy.getLogic ().success ())
{
m_socket->setProxyInfo (m_proxy.getLogic ().getInfo ());
m_owner.setProxyInfo (m_proxy.getLogic ().getInfo ());
// Strip off the PROXY flag.
m_socket->m_flags = m_socket->m_flags.without (Flag::proxy);
m_owner.m_flags = m_owner.m_flags.without (Flag::proxy);
// Update handshake state with the leftover bytes.
ec = m_socket->initHandshake (m_type, m_buffer.data ());
ec = m_owner.initHandshake (m_type, m_buffer.data ());
break;
}
// Didn't get the PROXY handshake we needed
ec = m_socket->handshake_error ();
ec = m_owner.handshake_error ();
break;
}
m_proxy.async_detect (m_socket->next_layer (),
m_proxy.async_detect (m_stream,
m_buffer, SharedHandlerPtr (this));
}
return;
@@ -810,21 +907,21 @@ protected:
if (m_ssl.getLogic ().success ())
{
// Convert the ssl flag to ssl_required
m_socket->m_flags = m_socket->m_flags.with (
m_owner.m_flags = m_owner.m_flags.with (
Flag::ssl_required).without (Flag::ssl);
}
else
{
// Not SSL, strip the ssl flag
m_socket->m_flags = m_socket->m_flags.without (Flag::ssl);
m_owner.m_flags = m_owner.m_flags.without (Flag::ssl);
}
// Update handshake state with the leftover bytes.
ec = m_socket->initHandshake (m_type, m_buffer.data ());
ec = m_owner.initHandshake (m_type, m_buffer.data ());
break;
}
m_ssl.async_detect (m_socket->next_layer (),
m_ssl.async_detect (m_stream,
m_buffer, SharedHandlerPtr (this));
}
return;
@@ -834,15 +931,15 @@ protected:
case stateHandshake:
default:
fatal_error ("invalid state");
ec = m_socket->handshake_error ();
ec = m_owner.handshake_error ();
}
}
while (! ec);
bassert (ec || (m_socket->m_state == stateReady && m_socket->m_needsShutdown));
bassert (ec || (m_owner.m_state == stateReady && m_owner.m_needsShutdown));
// Call the original handler with the error code and end.
m_socket->get_io_service ().wrap (
m_owner.get_io_service ().wrap (
BOOST_ASIO_MOVE_CAST (SharedHandlerPtr)(m_handler))
(ec);
}
@@ -854,7 +951,8 @@ protected:
private:
SharedHandlerPtr m_handler;
MultiSocketType <StreamSocket>* m_socket;
MultiSocketType <StreamSocket>& m_owner;
Stream& m_stream;
handshake_type const m_type;
boost::asio::basic_streambuf <Allocator> m_buffer;
HandshakeDetectorType <next_layer_type, HandshakeDetectLogicPROXY> m_proxy;
@@ -865,10 +963,14 @@ protected:
private:
Flag m_flags;
State m_state;
int m_verify_mode;
ScopedPointer <Socket> m_stream;
ScopedPointer <Socket> m_ssl_stream; // the ssl portion of our stream if it exists
bool m_needsShutdown;
StreamSocket m_next_layer;
boost::asio::io_service::strand m_strand;
SSL* m_native_ssl_handle;
};
#endif