Use handshake detect stream in MultiSocket

This commit is contained in:
Vinnie Falco
2013-08-16 13:42:37 -07:00
parent 4eb3504a75
commit 15dfaeb666
2 changed files with 243 additions and 47 deletions

View File

@@ -4,6 +4,8 @@
*/
//==============================================================================
#define MULTISOCKET_TEST_PROXY 1
MultiSocket* MultiSocket::New (boost::asio::io_service& io_service, int flags)
{
return new MultiSocketType <boost::asio::ip::tcp::socket> (io_service, flags);
@@ -155,25 +157,46 @@ public:
//--------------------------------------------------------------------------
template <typename Protocol, typename ClientArg, typename ServerArg>
void run_async (ClientArg const& clientArg, ServerArg const& serverArg)
void runProxy (ClientArg const& clientArg, ServerArg const& serverArg)
{
PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicAsyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this);
TestPeerLogicProxyClient, TestPeerLogicSyncServer> (clientArg, serverArg, timeoutSeconds).report (*this);
PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicProxyClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this);
}
//--------------------------------------------------------------------------
template <typename Protocol, typename ClientArg, typename ServerArg>
void run (ClientArg const& clientArg, ServerArg const& serverArg)
{
PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicSyncClient, TestPeerLogicSyncServer> (clientArg, serverArg, timeoutSeconds).report (*this);
TestPeerLogicSyncClient, TestPeerLogicSyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this);
PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicAsyncClient, TestPeerLogicSyncServer> (clientArg, serverArg, timeoutSeconds).report (*this);
TestPeerLogicAsyncClient, TestPeerLogicSyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this);
PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicSyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this);
TestPeerLogicSyncClient, TestPeerLogicAsyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this);
run_async <Protocol> (clientArg, serverArg);
PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicAsyncClient, TestPeerLogicAsyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this);
}
//--------------------------------------------------------------------------
template <typename Protocol>
void testProxyFlags (int extraClientFlags, int extraServerFlags)
{
check_precondition (! MultiSocket::Flag (extraClientFlags).any_set (MultiSocket::Flag::client_role | MultiSocket::Flag::server_role));
runProxy <Protocol> (MultiSocket::Flag::client_role | extraClientFlags,
MultiSocket::Flag::server_role | extraServerFlags);
}
//--------------------------------------------------------------------------
@@ -183,7 +206,8 @@ public:
{
check_precondition (! MultiSocket::Flag (extraClientFlags).any_set (MultiSocket::Flag::client_role | MultiSocket::Flag::server_role));
run <Protocol> (MultiSocket::Flag::client_role | extraClientFlags, MultiSocket::Flag::server_role | extraServerFlags);
run <Protocol> (MultiSocket::Flag::client_role | extraClientFlags,
MultiSocket::Flag::server_role | extraServerFlags);
}
template <typename Protocol>
@@ -199,6 +223,11 @@ public:
testFlags <Protocol> (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl_required);
testFlags <Protocol> (0, MultiSocket::Flag::ssl);
testFlags <Protocol> (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl);
testProxyFlags <Protocol> (MultiSocket::Flag::proxy, MultiSocket::Flag::proxy);
testProxyFlags <Protocol> (MultiSocket::Flag::proxy, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl);
testProxyFlags <Protocol> (MultiSocket::Flag::proxy | MultiSocket::Flag::ssl, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl);
testProxyFlags <Protocol> (MultiSocket::Flag::proxy | MultiSocket::Flag::ssl, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl_required);
}
void runTest ()

View File

@@ -32,12 +32,6 @@ public:
typedef typename next_layer_type::lowest_layer_type lowest_layer_type;
typedef typename boost::asio::ssl::stream <next_layer_type_ref> ssl_stream_type;
// These types are used internaly
//
// The type of stream that the HandshakeDetectStream uses
typedef next_layer_type& ssl_detector_stream_type;
template <class Arg>
explicit MultiSocketType (Arg& arg, int flags = 0)
: m_flags (cleaned_flags (flags))
@@ -158,8 +152,19 @@ protected:
bool requires_handshake ()
{
if (is_client ())
{
if (m_flags.set (Flag::ssl))
return true;
return false;
}
//bassert (is_server ());
if (m_stream != nullptr)
return stream ().requires_handshake ();
return true;
}
@@ -300,7 +305,6 @@ protected:
using namespace boost;
if (asio::buffer_size (buffers) > 0)
{
bassert (asio::buffer_size (buffers) > 0);
typedef PrefilledReadStream <next_layer_type&> next_type;
typedef asio::ssl::stream <next_type> this_type;
SocketWrapper <this_type>* const socket = new SocketWrapper <this_type> (
@@ -311,16 +315,44 @@ protected:
return new_ssl_stream ();
}
// Creates a stream that detects whether or not an
// SSL handshake is being attempted.
//
Socket* new_detect_ssl_stream ()
{
SSL3DetectCallback* callback = new SSL3DetectCallback (this);
typedef HandshakeDetectStreamType <ssl_detector_stream_type, HandshakeDetectLogicSSL3> this_type;
typedef HandshakeDetectStreamType <next_layer_type&, HandshakeDetectLogicSSL3> this_type;
return new SocketWrapper <this_type> (callback, m_next_layer);
}
Socket* new_proxy_stream ()
template <typename ConstBufferSequence>
Socket* new_detect_ssl_stream (ConstBufferSequence const& buffers)
{
return nullptr;
using namespace boost;
if (asio::buffer_size (buffers) > 0)
{
SSL3DetectCallback* callback = new SSL3DetectCallback (this);
typedef HandshakeDetectStreamType <next_layer_type&, HandshakeDetectLogicSSL3> this_type;
SocketWrapper <this_type>* const socket =
new SocketWrapper <this_type> (callback, m_next_layer);
socket->this_layer ().fill (buffers);
return socket;
}
return new_detect_ssl_stream ();
}
// Creates a stream that detects whether or not a PROXY
// handshake is being sent.
//
template <typename ConstBufferSequence>
Socket* new_proxy_stream (ConstBufferSequence const& buffers)
{
// Doesn't make sense to already have data if we're expecting PROXY
bassert (boost::asio::buffer_size (buffers) == 0);
PROXYDetectCallback* callback = new PROXYDetectCallback (this);
typedef HandshakeDetectStreamType <next_layer_type&, HandshakeDetectLogicPROXY> this_type;
return new SocketWrapper <this_type> (callback, m_next_layer);
}
//--------------------------------------------------------------------------
@@ -328,7 +360,7 @@ protected:
// Returns a stream appropriate for the current settings. If the
// return value is nullptr, it means that the stream cannot be determined
// until handshake or async_handshake is called. At that point, it will
// be necessary to use new_handshake_stream
// be necessary to use new_handshake_stream.
//
Socket* new_stream ()
{
@@ -338,9 +370,8 @@ protected:
{
if (m_flags.set (Flag::proxy))
{
// sending PROXY unimplemented for now
SocketBase::pure_virtual ();
socket = nullptr;
// Client sends PROXY in the plain
socket = new_plain_stream ();
}
else if (! m_flags.any_set (
Flag::ssl))
@@ -401,13 +432,18 @@ protected:
//
// Returns true if the stream was set
//
bool set_handshake_stream (handshake_type type, error_code& ec)
bool set_handshake_stream (handshake_type type,
error_code& ec, ConstBuffers const& buffers = ConstBuffers())
{
// VFALCO can't really assert on this because what if
// there was a previous stream?
//
// We better require a handshake if this is getting called
bassert (requires_handshake ());
//bassert (requires_handshake ());
// If a stream already exists, some logic went wrong.
bassert (m_stream == nullptr);
// Can't assert on this either, for the case of going proxy->ssldetect
//bassert (m_stream == nullptr);
ec = error_code ();
@@ -423,8 +459,8 @@ protected:
// If we constructed in a server or client role and
// the handshake type doesn't match, return an error.
//
if ( (type == client && ! m_flags.set (Flag::client_role)) ||
(type == server && ! m_flags.set (Flag::server_role)))
if ( (type == client && ! is_client ()) ||
(type == server && ! is_server ()))
{
handshake_error (ec);
return false;
@@ -432,37 +468,40 @@ protected:
Socket* socket = nullptr;
if (m_flags.set (Flag::client_role))
if (is_client ())
{
// should have filtered this out already
bassert (! m_flags.set (Flag::proxy));
// can't assert on this because we dont control the send,
// only when they decide to perform the handshake.
//bassert (! m_flags.set (Flag::proxy));
if (m_flags.set (Flag::ssl))
{
socket = new_ssl_stream ();
socket = new_ssl_stream (buffers);
}
else
{
// bad
socket = new_plain_stream (buffers);
}
}
else
{
bassert (is_server ());
if (m_flags.set (Flag::proxy))
{
socket = new_proxy_stream ();
socket = new_proxy_stream (buffers);
}
else if (m_flags.set (Flag::ssl_required))
{
socket = new_ssl_stream ();
socket = new_ssl_stream (buffers);
}
else if (m_flags.set (Flag::ssl))
{
socket = new_detect_ssl_stream ();
socket = new_detect_ssl_stream (buffers);
}
else
{
// bad
socket = new_plain_stream (buffers);
}
}
@@ -475,13 +514,14 @@ protected:
//--------------------------------------------------------------------------
void on_detect_ssl (error_code& ec, ConstBuffers const& buffers, bool is_ssl)
void on_detect_ssl (HandshakeDetectLogicSSL3& logic,
error_code& ec, ConstBuffers const& buffers)
{
bassert (is_server ());
if (! ec)
{
if (is_ssl)
if (logic.success ())
{
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
if (useBufferedHandshake)
@@ -502,12 +542,12 @@ protected:
}
// origHandler cannot be a reference since it gets move-copied
void on_async_detect_ssl (error_code const& ec,
ConstBuffers const& buffers, ErrorCall origHandler, bool is_ssl)
void on_async_detect_ssl (HandshakeDetectLogicSSL3& logic,
error_code const& ec, ConstBuffers const& buffers, ErrorCall origHandler)
{
if (! ec)
{
if (is_ssl)
if (logic.success ())
{
m_stream = new_ssl_stream (buffers);
// Is this a continuation?
@@ -527,14 +567,14 @@ protected:
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
// origHandler cannot be a reference since it gets move-copied
void on_async_detect_ssl (error_code const& ec,
ConstBuffers const& buffers, TransferCall origHandler, bool is_ssl)
void on_async_detect_ssl (HandshakeDetectLogicSSL3& logic,
error_code const& ec, ConstBuffers const& buffers, TransferCall origHandler)
{
std::size_t bytes_transferred = 0;
if (! ec)
{
if (is_ssl)
if (logic.success ())
{
// Is this a continuation?
m_stream = new_ssl_stream ();
@@ -556,9 +596,98 @@ protected:
}
#endif
//--------------------------------------------------------------------------
void setProxyInfo (HandshakeDetectLogicPROXY::ProxyInfo const proxyInfo)
{
// Do something with it
}
void on_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code& ec, ConstBuffers const& buffers)
{
bassert (is_server ());
if (ec)
return;
if (! logic.success ())
{
handshake_error (ec);
return;
}
// Remove the proxy flag
m_flags = m_flags.without (Flag::proxy);
setProxyInfo (logic.getInfo ());
if (set_handshake_stream (Socket::server, ec, buffers))
{
if (stream ().requires_handshake ())
stream ().handshake (Socket::server, ec);
}
}
// origHandler cannot be a reference since it gets move-copied
void on_async_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code ec, ConstBuffers const& buffers, ErrorCall origHandler)
{
if (! ec)
{
if (! logic.success ())
handshake_error (ec);
}
if (! ec)
{
// Remove the proxy flag
m_flags = m_flags.without (Flag::proxy);
setProxyInfo (logic.getInfo ());
if (set_handshake_stream (Socket::server, ec, buffers))
{
if (stream ().requires_handshake ())
{
stream ().async_handshake (Socket::server, origHandler);
}
else
{
get_io_service ().post (CompletionCall (origHandler, ec));
}
return;
}
else
{
bassert (ec);
}
}
if (ec)
get_io_service ().post (CompletionCall (origHandler, ec));
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
// origHandler cannot be a reference since it gets move-copied
void on_async_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code const& ec, ConstBuffers const& buffers, TransferCall origHandler)
{
std::size_t bytes_transferred = 0;
if (! ec)
{
if (logic.success ())
{
}
}
}
#endif
//--------------------------------------------------------------------------
//
// HandshakeDetectStream calback
// SSL3 Callback
//
class SSL3DetectCallback
: public HandshakeDetectStream <HandshakeDetectLogicSSL3>::Callback
@@ -574,20 +703,58 @@ protected:
void on_detect (LogicType& logic, error_code& ec,
ConstBuffers const& buffers)
{
m_owner->on_detect_ssl (ec, buffers, logic.success ());
m_owner->on_detect_ssl (logic, ec, buffers);
}
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, ErrorCall const& origHandler)
{
m_owner->on_async_detect_ssl (ec, buffers, origHandler, logic.success ());
m_owner->on_async_detect_ssl (logic, ec, buffers, origHandler);
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, TransferCall const& origHandler)
{
m_owner->on_async_detect_ssl (ec, buffers, origHandler, logic.success ());
m_owner->on_async_detect_ssl (logic, ec, buffers, origHandler);
}
#endif
private:
MultiSocketType <StreamSocket>* m_owner;
};
//--------------------------------------------------------------------------
//
// PROXY callback
//
class PROXYDetectCallback
: public HandshakeDetectStream <HandshakeDetectLogicPROXY>::Callback
{
public:
typedef HandshakeDetectLogicPROXY LogicType;
explicit PROXYDetectCallback (MultiSocketType <StreamSocket>* owner)
: m_owner (owner)
{
}
void on_detect (LogicType& logic, error_code& ec, ConstBuffers const& buffers)
{
m_owner->on_detect_proxy (logic, ec, buffers);
}
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, ErrorCall const& origHandler)
{
m_owner->on_async_detect_proxy (logic, ec, buffers, origHandler);
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, TransferCall const& origHandler)
{
m_owner->on_async_detect_proxy (logic, ec, buffers, origHandler);
}
#endif