Fixes to handshake and socket wrapper

This commit is contained in:
Vinnie Falco
2013-08-16 16:10:55 -07:00
parent 191562c13c
commit 0a58333edc
10 changed files with 292 additions and 242 deletions

View File

@@ -35,6 +35,9 @@ public:
template <class Arg>
explicit MultiSocketType (Arg& arg, int flags = 0)
: m_flags (cleaned_flags (flags))
, m_needs_handshake (false)
, m_needs_shutdown (false)
, m_needs_shutdown_stream (false)
, m_next_layer (arg)
, m_stream (new_stream ())
, m_strand (get_io_service ())
@@ -42,6 +45,10 @@ public:
}
protected:
/** The current stream we are passing everything through.
This object gets dynamically created and replaced with other
objects as we process the various flags for handshaking.
*/
Socket& stream () const noexcept
{
bassert (m_stream != nullptr);
@@ -150,27 +157,42 @@ protected:
// ssl::stream
//
bool requires_handshake ()
/** Determine if the caller needs to call a handshaking function. */
bool needs_handshake ()
{
if (is_client ())
#if 1
return m_needs_handshake;
#else
if (is_server ())
{
// If the PROXY flag is on, we need that first
if (m_flags.set (Flag::proxy))
return true;
if (m_flags.any_set (Flag::ssl | Flag::ssl_required))
return true;
return false;
}
else if (is_client ())
{
// Client sends PROXY on the initial connect
// before We even get to this function.
if (m_flags.set (Flag::ssl))
return true;
return false;
}
//bassert (is_server ());
if (m_stream != nullptr)
return stream ().requires_handshake ();
return true;
bassert (is_peer ());
return false;
#endif
}
error_code handshake (Socket::handshake_type type, error_code& ec)
{
if (set_handshake_stream (type, ec))
if (init_handshake (type, ec))
return stream ().handshake (type, ec);
return ec;
}
@@ -179,17 +201,18 @@ protected:
async_handshake (handshake_type type, BOOST_ASIO_MOVE_ARG(ErrorCall) handler)
{
error_code ec;
if (set_handshake_stream (type, ec))
if (init_handshake (type, ec))
return stream ().async_handshake (type,
BOOST_ASIO_MOVE_CAST(ErrorCall)(handler));
return get_io_service ().post (CompletionCall (handler, ec));
return get_io_service ().post (CompletionCall (
BOOST_ASIO_MOVE_CAST(ErrorCall)(handler), ec));
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
error_code handshake (handshake_type type,
ConstBuffers const& buffers, error_code& ec)
{
if (set_handshake_stream (type, ec))
if (init_handshake (type, ec))
return stream ().handshake (type, buffers, ec);
return ec;
}
@@ -199,7 +222,7 @@ protected:
BOOST_ASIO_MOVE_ARG(TransferCall) handler)
{
error_code ec;
if (set_handshake_stream (type, ec))
if (init_handshake (type, ec))
return stream ().async_handshake (type, buffers,
BOOST_ASIO_MOVE_CAST(TransferCall)(handler));
return get_io_service ().post (CompletionCall (handler, ec, 0));
@@ -208,25 +231,33 @@ protected:
error_code shutdown (error_code& ec)
{
bassert (requires_handshake ());
if (stream ().requires_handshake ())
return stream ().shutdown (ec);
if (m_needs_shutdown)
{
if (m_needs_shutdown_stream)
return stream ().shutdown (ec);
return ec = error_code ();
}
return handshake_error (ec);
}
BEAST_ASIO_INITFN_RESULT_TYPE_MEMBER(ErrorCall, void (error_code))
async_shutdown (BOOST_ASIO_MOVE_ARG(ErrorCall) handler)
{
bassert (requires_handshake ());
if (stream ().requires_handshake ())
return stream ().async_shutdown (BOOST_ASIO_MOVE_CAST(ErrorCall)(handler));
error_code ec;
handshake_error (ec);
return get_io_service ().post (CompletionCall (handler, ec));
if (m_needs_shutdown)
{
if (m_needs_shutdown_stream)
return stream ().async_shutdown (
BOOST_ASIO_MOVE_CAST(ErrorCall)(handler));
}
else
{
ec = handshake_error ();
}
return get_io_service ().post (CompletionCall (
BOOST_ASIO_MOVE_CAST(ErrorCall)(handler),
ec));
}
//--------------------------------------------------------------------------
@@ -277,8 +308,7 @@ protected:
template <typename ConstBufferSequence>
Socket* new_plain_stream (ConstBufferSequence const& buffers)
{
using namespace boost;
if (asio::buffer_size (buffers) > 0)
if (boost::asio::buffer_size (buffers) > 0)
{
typedef PrefilledReadStream <next_layer_type&> this_type;
return new SocketWrapper <this_type> (m_next_layer, buffers);
@@ -290,8 +320,7 @@ protected:
//
Socket* new_ssl_stream ()
{
using namespace boost;
typedef typename asio::ssl::stream <next_layer_type&> this_type;
typedef typename boost::asio::ssl::stream <next_layer_type&> this_type;
return new SocketWrapper <this_type> (
m_next_layer, MultiSocket::getRippleTlsBoostContext ());
}
@@ -302,11 +331,10 @@ protected:
template <typename ConstBufferSequence>
Socket* new_ssl_stream (ConstBufferSequence const& buffers)
{
using namespace boost;
if (asio::buffer_size (buffers) > 0)
if (boost::asio::buffer_size (buffers) > 0)
{
typedef PrefilledReadStream <next_layer_type&> next_type;
typedef asio::ssl::stream <next_type> this_type;
typedef boost::asio::ssl::stream <next_type> this_type;
SocketWrapper <this_type>* const socket = new SocketWrapper <this_type> (
m_next_layer, MultiSocket::getRippleTlsBoostContext ());
socket->this_layer().next_layer().fill (buffers);
@@ -325,11 +353,12 @@ protected:
return new SocketWrapper <this_type> (callback, m_next_layer);
}
// Creates an ssl detect stream and front-loads it with data.
//
template <typename ConstBufferSequence>
Socket* new_detect_ssl_stream (ConstBufferSequence const& buffers)
{
using namespace boost;
if (asio::buffer_size (buffers) > 0)
if (boost::asio::buffer_size (buffers) > 0)
{
SSL3DetectCallback* callback = new SSL3DetectCallback (this);
typedef HandshakeDetectStreamType <next_layer_type&, HandshakeDetectLogicSSL3> this_type;
@@ -348,7 +377,7 @@ protected:
Socket* new_proxy_stream (ConstBufferSequence const& buffers)
{
// Doesn't make sense to already have data if we're expecting PROXY
bassert (boost::asio::buffer_size (buffers) == 0);
check_precondition (boost::asio::buffer_size (buffers) == 0);
PROXYDetectCallback* callback = new PROXYDetectCallback (this);
typedef HandshakeDetectStreamType <next_layer_type&, HandshakeDetectLogicPROXY> this_type;
@@ -357,10 +386,12 @@ protected:
//--------------------------------------------------------------------------
// Returns a stream appropriate for the current settings. If the
// return value is nullptr, it means that the stream cannot be determined
// until handshake or async_handshake is called. At that point, it will
// be necessary to use new_handshake_stream.
// Called at construction time, this returns a stream appropriate for
// the current settings. If the return value is nullptr, it means that
// the stream cannot be determined until handshake or async_handshake
// is called. In that case, init_handshake will set m_stream.
//
// This also sets the m_needs_handshake and m_needs_shutdown flags
//
Socket* new_stream ()
{
@@ -370,46 +401,56 @@ protected:
{
if (m_flags.set (Flag::proxy))
{
// Client sends PROXY in the plain
// Client sends PROXY in the plain so make
// sure they have an underlying stream right away.
socket = new_plain_stream ();
m_needs_handshake = m_flags.set (Flag::ssl);
}
else if (! m_flags.any_set (
Flag::ssl))
else if (m_flags.set (Flag::ssl))
{
socket = new_plain_stream ();
socket = nullptr;
m_needs_handshake = true;
}
else
{
// A call to handshake or async_handshake is
// required to determine the type of stream to create.
//
socket = nullptr;
socket = new_plain_stream ();
m_needs_handshake = false;
}
}
else if (is_server ())
{
// As long as no flags indicate a handshake may be
// required, we can start out with a plain stream.
//
if (! m_flags.any_set (
Flag::ssl | Flag::ssl_required | Flag::proxy))
if (m_flags.set (Flag::proxy))
{
socket = new_plain_stream ();
// PROXY comes first in the plain
socket = nullptr;
m_needs_handshake = true;
}
else if (m_flags.set (Flag::ssl_required))
{
socket = nullptr;
m_needs_handshake = true;
}
else if (m_flags.set (Flag::ssl))
{
socket = nullptr;
m_needs_handshake = true;
}
else
{
// A call to handshake or async_handshake is
// required to determine the type of stream to create.
//
socket = nullptr;
m_needs_handshake = false;
socket = new_plain_stream ();
}
}
else
{
// peer, plain stream, no handshake
m_needs_handshake = false;
socket = new_plain_stream ();
}
// We only set this to true when the handshake succeeds.
m_needs_shutdown = false;
return socket;
}
@@ -417,12 +458,18 @@ protected:
// Bottleneck to indicate a failed handshake.
//
error_code handshake_error (error_code& ec)
error_code handshake_error ()
{
// VFALCO TODO maybe use a ripple error category?
// set this to something custom that we can recognize later?
//
return ec = boost::asio::error::invalid_argument;
m_needs_shutdown = false;
return boost::asio::error::invalid_argument;
}
error_code handshake_error (error_code& ec)
{
return ec = handshake_error ();
}
// Sanity checks and cleans up our flags based on the desired handshake
@@ -430,35 +477,23 @@ protected:
// If the requested handshake type is incompatible with the socket flags
// set on construction, an error is returned.
//
// Returns true if the stream was set
// This also clears the m_needs_handshake flag on an error.
//
bool set_handshake_stream (handshake_type type,
// Returns true if the stream was set.
//
bool init_handshake (handshake_type type,
error_code& ec, ConstBuffers const& buffers = ConstBuffers())
{
// VFALCO can't really assert on this because what if
// there was a previous stream?
//
// We better require a handshake if this is getting called
//bassert (requires_handshake ());
bassert (m_needs_handshake);
// If a stream already exists, some logic went wrong.
// Can't assert on this either, for the case of going proxy->ssldetect
//bassert (m_stream == nullptr);
ec = error_code ();
// If we constructed as a peer role then return
// an error since there is no handshake required.
//
// Peer roles cannot handshake.
if (m_flags == Flag (Flag::peer))
{
handshake_error (ec);
return false;
}
// If we constructed in a server or client role and
// the handshake type doesn't match, return an error.
//
// Handshake type must match the role flags
if ( (type == client && ! is_client ()) ||
(type == server && ! is_server ()))
{
@@ -466,22 +501,21 @@ protected:
return false;
}
Socket* socket = nullptr;
ec = error_code ();
if (is_client ())
{
// can't assert on this because we dont control the send,
// only when they decide to perform the handshake.
//bassert (! m_flags.set (Flag::proxy));
// PROXY sent before handshake, in plain
m_flags = m_flags.without (Flag::proxy);
if (m_flags.set (Flag::ssl))
if (! m_flags.set (Flag::ssl))
{
socket = new_ssl_stream (buffers);
}
else
{
socket = new_plain_stream (buffers);
handshake_error (ec);
return false;
}
m_stream = new_ssl_stream (buffers);
m_needs_shutdown_stream = true;
}
else
{
@@ -489,29 +523,149 @@ protected:
if (m_flags.set (Flag::proxy))
{
socket = new_proxy_stream (buffers);
m_stream = new_proxy_stream (ConstBuffers ());
}
else if (m_flags.set (Flag::ssl_required))
{
socket = new_ssl_stream (buffers);
m_stream = new_ssl_stream (buffers);
m_needs_shutdown_stream = true;
}
else if (m_flags.set (Flag::ssl))
{
socket = new_detect_ssl_stream (buffers);
m_stream = new_detect_ssl_stream (buffers);
}
else
{
socket = new_plain_stream (buffers);
// This happens after a proxy handshake
m_stream = new_plain_stream (buffers);
}
}
check_postcondition (socket != nullptr);
m_stream = socket;
m_needs_shutdown = true;
return true;
}
//--------------------------------------------------------------------------
//
// PROXY callback
//
class PROXYDetectCallback
: public HandshakeDetectStream <HandshakeDetectLogicPROXY>::Callback
{
public:
typedef HandshakeDetectLogicPROXY LogicType;
explicit PROXYDetectCallback (MultiSocketType <StreamSocket>* owner)
: m_owner (owner)
{
}
void on_detect (LogicType& logic,
error_code& ec, ConstBuffers const& buffers)
{
m_owner->on_detect_proxy (logic, ec, buffers);
}
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, ErrorCall const& origHandler)
{
m_owner->on_async_detect_proxy (logic, ec, buffers, origHandler);
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, TransferCall const& origHandler)
{
m_owner->on_async_detect_proxy (logic, ec, buffers, origHandler);
}
#endif
private:
MultiSocketType <StreamSocket>* m_owner;
};
//--------------------------------------------------------------------------
void setProxyInfo (HandshakeDetectLogicPROXY::ProxyInfo const proxyInfo)
{
// Do something with it
}
void on_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code& ec, ConstBuffers const& buffers)
{
bassert (is_server ());
if (ec)
return;
if (! logic.success ())
{
handshake_error (ec);
return;
}
// Remove the proxy flag
m_flags = m_flags.without (Flag::proxy);
setProxyInfo (logic.getInfo ());
if (init_handshake (Socket::server, ec, buffers))
{
if (stream ().needs_handshake ())
stream ().handshake (Socket::server, ec);
}
}
// origHandler cannot be a reference since it gets move-copied
void on_async_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code ec, ConstBuffers const& buffers, ErrorCall origHandler)
{
if (! ec)
{
// Failed to receive PROXY handshake
if (! logic.success ())
handshake_error (ec);
}
if (! ec)
{
// Remove the proxy flag from our flag set
m_flags = m_flags.without (Flag::proxy);
setProxyInfo (logic.getInfo ());
if (init_handshake (Socket::server, ec, buffers))
{
if (stream ().needs_handshake ())
{
stream ().async_handshake (Socket::server, origHandler);
return;
}
}
}
get_io_service ().post (CompletionCall (
BOOST_ASIO_MOVE_CAST(ErrorCall)(origHandler), ec));
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
// origHandler cannot be a reference since it gets move-copied
void on_async_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code const& ec, ConstBuffers const& buffers, TransferCall origHandler)
{
std::size_t bytes_transferred = 0;
if (! ec)
{
if (logic.success ())
{
}
}
}
#endif
//--------------------------------------------------------------------------
void on_detect_ssl (HandshakeDetectLogicSSL3& logic,
@@ -533,6 +687,8 @@ protected:
#endif
m_stream = new_ssl_stream (buffers);
m_needs_shutdown_stream = true;
stream ().handshake (Socket::server, ec);
return;
}
@@ -550,6 +706,7 @@ protected:
if (logic.success ())
{
m_stream = new_ssl_stream (buffers);
m_needs_shutdown_stream = true;
// Is this a continuation?
stream ().async_handshake (Socket::server, origHandler);
return;
@@ -596,95 +753,6 @@ protected:
}
#endif
//--------------------------------------------------------------------------
void setProxyInfo (HandshakeDetectLogicPROXY::ProxyInfo const proxyInfo)
{
// Do something with it
}
void on_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code& ec, ConstBuffers const& buffers)
{
bassert (is_server ());
if (ec)
return;
if (! logic.success ())
{
handshake_error (ec);
return;
}
// Remove the proxy flag
m_flags = m_flags.without (Flag::proxy);
setProxyInfo (logic.getInfo ());
if (set_handshake_stream (Socket::server, ec, buffers))
{
if (stream ().requires_handshake ())
stream ().handshake (Socket::server, ec);
}
}
// origHandler cannot be a reference since it gets move-copied
void on_async_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code ec, ConstBuffers const& buffers, ErrorCall origHandler)
{
if (! ec)
{
if (! logic.success ())
handshake_error (ec);
}
if (! ec)
{
// Remove the proxy flag
m_flags = m_flags.without (Flag::proxy);
setProxyInfo (logic.getInfo ());
if (set_handshake_stream (Socket::server, ec, buffers))
{
if (stream ().requires_handshake ())
{
stream ().async_handshake (Socket::server, origHandler);
}
else
{
get_io_service ().post (CompletionCall (origHandler, ec));
}
return;
}
else
{
bassert (ec);
}
}
if (ec)
get_io_service ().post (CompletionCall (origHandler, ec));
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
// origHandler cannot be a reference since it gets move-copied
void on_async_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code const& ec, ConstBuffers const& buffers, TransferCall origHandler)
{
std::size_t bytes_transferred = 0;
if (! ec)
{
if (logic.success ())
{
}
}
}
#endif
//--------------------------------------------------------------------------
//
// SSL3 Callback
@@ -724,44 +792,6 @@ protected:
MultiSocketType <StreamSocket>* m_owner;
};
//--------------------------------------------------------------------------
//
// PROXY callback
//
class PROXYDetectCallback
: public HandshakeDetectStream <HandshakeDetectLogicPROXY>::Callback
{
public:
typedef HandshakeDetectLogicPROXY LogicType;
explicit PROXYDetectCallback (MultiSocketType <StreamSocket>* owner)
: m_owner (owner)
{
}
void on_detect (LogicType& logic, error_code& ec, ConstBuffers const& buffers)
{
m_owner->on_detect_proxy (logic, ec, buffers);
}
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, ErrorCall const& origHandler)
{
m_owner->on_async_detect_proxy (logic, ec, buffers, origHandler);
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, TransferCall const& origHandler)
{
m_owner->on_async_detect_proxy (logic, ec, buffers, origHandler);
}
#endif
private:
MultiSocketType <StreamSocket>* m_owner;
};
//--------------------------------------------------------------------------
//
// Utilities
@@ -777,8 +807,17 @@ protected:
return m_flags.set (Flag::client_role);
}
bool is_peer () const noexcept
{
bassert (is_client () || is_server () || (m_flags == Flag (Flag::peer)));
return m_flags == Flag (Flag::peer);
}
private:
Flag m_flags;
bool m_needs_handshake;
bool m_needs_shutdown;
bool m_needs_shutdown_stream;
StreamSocket m_next_layer;
ScopedPointer <Socket> m_stream;
boost::asio::io_service::strand m_strand;