Use handshake detect stream in MultiSocket

This commit is contained in:
Vinnie Falco
2013-08-16 13:42:37 -07:00
parent 4eb3504a75
commit 15dfaeb666
2 changed files with 243 additions and 47 deletions

View File

@@ -4,6 +4,8 @@
*/ */
//============================================================================== //==============================================================================
#define MULTISOCKET_TEST_PROXY 1
MultiSocket* MultiSocket::New (boost::asio::io_service& io_service, int flags) MultiSocket* MultiSocket::New (boost::asio::io_service& io_service, int flags)
{ {
return new MultiSocketType <boost::asio::ip::tcp::socket> (io_service, flags); return new MultiSocketType <boost::asio::ip::tcp::socket> (io_service, flags);
@@ -155,25 +157,46 @@ public:
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
template <typename Protocol, typename ClientArg, typename ServerArg> template <typename Protocol, typename ClientArg, typename ServerArg>
void run_async (ClientArg const& clientArg, ServerArg const& serverArg) void runProxy (ClientArg const& clientArg, ServerArg const& serverArg)
{ {
PeerTest::run <MultiSocketDetailsType <Protocol>, PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicAsyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); TestPeerLogicProxyClient, TestPeerLogicSyncServer> (clientArg, serverArg, timeoutSeconds).report (*this);
PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicProxyClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this);
} }
//--------------------------------------------------------------------------
template <typename Protocol, typename ClientArg, typename ServerArg> template <typename Protocol, typename ClientArg, typename ServerArg>
void run (ClientArg const& clientArg, ServerArg const& serverArg) void run (ClientArg const& clientArg, ServerArg const& serverArg)
{ {
PeerTest::run <MultiSocketDetailsType <Protocol>, PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicSyncClient, TestPeerLogicSyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); TestPeerLogicSyncClient, TestPeerLogicSyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this);
PeerTest::run <MultiSocketDetailsType <Protocol>, PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicAsyncClient, TestPeerLogicSyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); TestPeerLogicAsyncClient, TestPeerLogicSyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this);
PeerTest::run <MultiSocketDetailsType <Protocol>, PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicSyncClient, TestPeerLogicAsyncServer> (clientArg, serverArg, timeoutSeconds).report (*this); TestPeerLogicSyncClient, TestPeerLogicAsyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this);
run_async <Protocol> (clientArg, serverArg); PeerTest::run <MultiSocketDetailsType <Protocol>,
TestPeerLogicAsyncClient, TestPeerLogicAsyncServer>
(clientArg, serverArg, timeoutSeconds).report (*this);
}
//--------------------------------------------------------------------------
template <typename Protocol>
void testProxyFlags (int extraClientFlags, int extraServerFlags)
{
check_precondition (! MultiSocket::Flag (extraClientFlags).any_set (MultiSocket::Flag::client_role | MultiSocket::Flag::server_role));
runProxy <Protocol> (MultiSocket::Flag::client_role | extraClientFlags,
MultiSocket::Flag::server_role | extraServerFlags);
} }
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
@@ -183,7 +206,8 @@ public:
{ {
check_precondition (! MultiSocket::Flag (extraClientFlags).any_set (MultiSocket::Flag::client_role | MultiSocket::Flag::server_role)); check_precondition (! MultiSocket::Flag (extraClientFlags).any_set (MultiSocket::Flag::client_role | MultiSocket::Flag::server_role));
run <Protocol> (MultiSocket::Flag::client_role | extraClientFlags, MultiSocket::Flag::server_role | extraServerFlags); run <Protocol> (MultiSocket::Flag::client_role | extraClientFlags,
MultiSocket::Flag::server_role | extraServerFlags);
} }
template <typename Protocol> template <typename Protocol>
@@ -199,6 +223,11 @@ public:
testFlags <Protocol> (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl_required); testFlags <Protocol> (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl_required);
testFlags <Protocol> (0, MultiSocket::Flag::ssl); testFlags <Protocol> (0, MultiSocket::Flag::ssl);
testFlags <Protocol> (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl); testFlags <Protocol> (MultiSocket::Flag::ssl, MultiSocket::Flag::ssl);
testProxyFlags <Protocol> (MultiSocket::Flag::proxy, MultiSocket::Flag::proxy);
testProxyFlags <Protocol> (MultiSocket::Flag::proxy, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl);
testProxyFlags <Protocol> (MultiSocket::Flag::proxy | MultiSocket::Flag::ssl, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl);
testProxyFlags <Protocol> (MultiSocket::Flag::proxy | MultiSocket::Flag::ssl, MultiSocket::Flag::proxy | MultiSocket::Flag::ssl_required);
} }
void runTest () void runTest ()

View File

@@ -32,12 +32,6 @@ public:
typedef typename next_layer_type::lowest_layer_type lowest_layer_type; typedef typename next_layer_type::lowest_layer_type lowest_layer_type;
typedef typename boost::asio::ssl::stream <next_layer_type_ref> ssl_stream_type; typedef typename boost::asio::ssl::stream <next_layer_type_ref> ssl_stream_type;
// These types are used internaly
//
// The type of stream that the HandshakeDetectStream uses
typedef next_layer_type& ssl_detector_stream_type;
template <class Arg> template <class Arg>
explicit MultiSocketType (Arg& arg, int flags = 0) explicit MultiSocketType (Arg& arg, int flags = 0)
: m_flags (cleaned_flags (flags)) : m_flags (cleaned_flags (flags))
@@ -158,8 +152,19 @@ protected:
bool requires_handshake () bool requires_handshake ()
{ {
if (is_client ())
{
if (m_flags.set (Flag::ssl))
return true;
return false;
}
//bassert (is_server ());
if (m_stream != nullptr) if (m_stream != nullptr)
return stream ().requires_handshake (); return stream ().requires_handshake ();
return true; return true;
} }
@@ -300,7 +305,6 @@ protected:
using namespace boost; using namespace boost;
if (asio::buffer_size (buffers) > 0) if (asio::buffer_size (buffers) > 0)
{ {
bassert (asio::buffer_size (buffers) > 0);
typedef PrefilledReadStream <next_layer_type&> next_type; typedef PrefilledReadStream <next_layer_type&> next_type;
typedef asio::ssl::stream <next_type> this_type; typedef asio::ssl::stream <next_type> this_type;
SocketWrapper <this_type>* const socket = new SocketWrapper <this_type> ( SocketWrapper <this_type>* const socket = new SocketWrapper <this_type> (
@@ -311,16 +315,44 @@ protected:
return new_ssl_stream (); return new_ssl_stream ();
} }
// Creates a stream that detects whether or not an
// SSL handshake is being attempted.
//
Socket* new_detect_ssl_stream () Socket* new_detect_ssl_stream ()
{ {
SSL3DetectCallback* callback = new SSL3DetectCallback (this); SSL3DetectCallback* callback = new SSL3DetectCallback (this);
typedef HandshakeDetectStreamType <ssl_detector_stream_type, HandshakeDetectLogicSSL3> this_type; typedef HandshakeDetectStreamType <next_layer_type&, HandshakeDetectLogicSSL3> this_type;
return new SocketWrapper <this_type> (callback, m_next_layer); return new SocketWrapper <this_type> (callback, m_next_layer);
} }
Socket* new_proxy_stream () template <typename ConstBufferSequence>
Socket* new_detect_ssl_stream (ConstBufferSequence const& buffers)
{ {
return nullptr; using namespace boost;
if (asio::buffer_size (buffers) > 0)
{
SSL3DetectCallback* callback = new SSL3DetectCallback (this);
typedef HandshakeDetectStreamType <next_layer_type&, HandshakeDetectLogicSSL3> this_type;
SocketWrapper <this_type>* const socket =
new SocketWrapper <this_type> (callback, m_next_layer);
socket->this_layer ().fill (buffers);
return socket;
}
return new_detect_ssl_stream ();
}
// Creates a stream that detects whether or not a PROXY
// handshake is being sent.
//
template <typename ConstBufferSequence>
Socket* new_proxy_stream (ConstBufferSequence const& buffers)
{
// Doesn't make sense to already have data if we're expecting PROXY
bassert (boost::asio::buffer_size (buffers) == 0);
PROXYDetectCallback* callback = new PROXYDetectCallback (this);
typedef HandshakeDetectStreamType <next_layer_type&, HandshakeDetectLogicPROXY> this_type;
return new SocketWrapper <this_type> (callback, m_next_layer);
} }
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
@@ -328,7 +360,7 @@ protected:
// Returns a stream appropriate for the current settings. If the // Returns a stream appropriate for the current settings. If the
// return value is nullptr, it means that the stream cannot be determined // return value is nullptr, it means that the stream cannot be determined
// until handshake or async_handshake is called. At that point, it will // until handshake or async_handshake is called. At that point, it will
// be necessary to use new_handshake_stream // be necessary to use new_handshake_stream.
// //
Socket* new_stream () Socket* new_stream ()
{ {
@@ -338,9 +370,8 @@ protected:
{ {
if (m_flags.set (Flag::proxy)) if (m_flags.set (Flag::proxy))
{ {
// sending PROXY unimplemented for now // Client sends PROXY in the plain
SocketBase::pure_virtual (); socket = new_plain_stream ();
socket = nullptr;
} }
else if (! m_flags.any_set ( else if (! m_flags.any_set (
Flag::ssl)) Flag::ssl))
@@ -401,13 +432,18 @@ protected:
// //
// Returns true if the stream was set // Returns true if the stream was set
// //
bool set_handshake_stream (handshake_type type, error_code& ec) bool set_handshake_stream (handshake_type type,
error_code& ec, ConstBuffers const& buffers = ConstBuffers())
{ {
// VFALCO can't really assert on this because what if
// there was a previous stream?
//
// We better require a handshake if this is getting called // We better require a handshake if this is getting called
bassert (requires_handshake ()); //bassert (requires_handshake ());
// If a stream already exists, some logic went wrong. // If a stream already exists, some logic went wrong.
bassert (m_stream == nullptr); // Can't assert on this either, for the case of going proxy->ssldetect
//bassert (m_stream == nullptr);
ec = error_code (); ec = error_code ();
@@ -423,8 +459,8 @@ protected:
// If we constructed in a server or client role and // If we constructed in a server or client role and
// the handshake type doesn't match, return an error. // the handshake type doesn't match, return an error.
// //
if ( (type == client && ! m_flags.set (Flag::client_role)) || if ( (type == client && ! is_client ()) ||
(type == server && ! m_flags.set (Flag::server_role))) (type == server && ! is_server ()))
{ {
handshake_error (ec); handshake_error (ec);
return false; return false;
@@ -432,37 +468,40 @@ protected:
Socket* socket = nullptr; Socket* socket = nullptr;
if (m_flags.set (Flag::client_role)) if (is_client ())
{ {
// should have filtered this out already // can't assert on this because we dont control the send,
bassert (! m_flags.set (Flag::proxy)); // only when they decide to perform the handshake.
//bassert (! m_flags.set (Flag::proxy));
if (m_flags.set (Flag::ssl)) if (m_flags.set (Flag::ssl))
{ {
socket = new_ssl_stream (); socket = new_ssl_stream (buffers);
} }
else else
{ {
// bad socket = new_plain_stream (buffers);
} }
} }
else else
{ {
bassert (is_server ());
if (m_flags.set (Flag::proxy)) if (m_flags.set (Flag::proxy))
{ {
socket = new_proxy_stream (); socket = new_proxy_stream (buffers);
} }
else if (m_flags.set (Flag::ssl_required)) else if (m_flags.set (Flag::ssl_required))
{ {
socket = new_ssl_stream (); socket = new_ssl_stream (buffers);
} }
else if (m_flags.set (Flag::ssl)) else if (m_flags.set (Flag::ssl))
{ {
socket = new_detect_ssl_stream (); socket = new_detect_ssl_stream (buffers);
} }
else else
{ {
// bad socket = new_plain_stream (buffers);
} }
} }
@@ -475,13 +514,14 @@ protected:
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
void on_detect_ssl (error_code& ec, ConstBuffers const& buffers, bool is_ssl) void on_detect_ssl (HandshakeDetectLogicSSL3& logic,
error_code& ec, ConstBuffers const& buffers)
{ {
bassert (is_server ()); bassert (is_server ());
if (! ec) if (! ec)
{ {
if (is_ssl) if (logic.success ())
{ {
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
if (useBufferedHandshake) if (useBufferedHandshake)
@@ -502,12 +542,12 @@ protected:
} }
// origHandler cannot be a reference since it gets move-copied // origHandler cannot be a reference since it gets move-copied
void on_async_detect_ssl (error_code const& ec, void on_async_detect_ssl (HandshakeDetectLogicSSL3& logic,
ConstBuffers const& buffers, ErrorCall origHandler, bool is_ssl) error_code const& ec, ConstBuffers const& buffers, ErrorCall origHandler)
{ {
if (! ec) if (! ec)
{ {
if (is_ssl) if (logic.success ())
{ {
m_stream = new_ssl_stream (buffers); m_stream = new_ssl_stream (buffers);
// Is this a continuation? // Is this a continuation?
@@ -527,14 +567,14 @@ protected:
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
// origHandler cannot be a reference since it gets move-copied // origHandler cannot be a reference since it gets move-copied
void on_async_detect_ssl (error_code const& ec, void on_async_detect_ssl (HandshakeDetectLogicSSL3& logic,
ConstBuffers const& buffers, TransferCall origHandler, bool is_ssl) error_code const& ec, ConstBuffers const& buffers, TransferCall origHandler)
{ {
std::size_t bytes_transferred = 0; std::size_t bytes_transferred = 0;
if (! ec) if (! ec)
{ {
if (is_ssl) if (logic.success ())
{ {
// Is this a continuation? // Is this a continuation?
m_stream = new_ssl_stream (); m_stream = new_ssl_stream ();
@@ -556,9 +596,98 @@ protected:
} }
#endif #endif
//--------------------------------------------------------------------------
void setProxyInfo (HandshakeDetectLogicPROXY::ProxyInfo const proxyInfo)
{
// Do something with it
}
void on_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code& ec, ConstBuffers const& buffers)
{
bassert (is_server ());
if (ec)
return;
if (! logic.success ())
{
handshake_error (ec);
return;
}
// Remove the proxy flag
m_flags = m_flags.without (Flag::proxy);
setProxyInfo (logic.getInfo ());
if (set_handshake_stream (Socket::server, ec, buffers))
{
if (stream ().requires_handshake ())
stream ().handshake (Socket::server, ec);
}
}
// origHandler cannot be a reference since it gets move-copied
void on_async_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code ec, ConstBuffers const& buffers, ErrorCall origHandler)
{
if (! ec)
{
if (! logic.success ())
handshake_error (ec);
}
if (! ec)
{
// Remove the proxy flag
m_flags = m_flags.without (Flag::proxy);
setProxyInfo (logic.getInfo ());
if (set_handshake_stream (Socket::server, ec, buffers))
{
if (stream ().requires_handshake ())
{
stream ().async_handshake (Socket::server, origHandler);
}
else
{
get_io_service ().post (CompletionCall (origHandler, ec));
}
return;
}
else
{
bassert (ec);
}
}
if (ec)
get_io_service ().post (CompletionCall (origHandler, ec));
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
// origHandler cannot be a reference since it gets move-copied
void on_async_detect_proxy (HandshakeDetectLogicPROXY& logic,
error_code const& ec, ConstBuffers const& buffers, TransferCall origHandler)
{
std::size_t bytes_transferred = 0;
if (! ec)
{
if (logic.success ())
{
}
}
}
#endif
//-------------------------------------------------------------------------- //--------------------------------------------------------------------------
// //
// HandshakeDetectStream calback // SSL3 Callback
// //
class SSL3DetectCallback class SSL3DetectCallback
: public HandshakeDetectStream <HandshakeDetectLogicSSL3>::Callback : public HandshakeDetectStream <HandshakeDetectLogicSSL3>::Callback
@@ -574,20 +703,58 @@ protected:
void on_detect (LogicType& logic, error_code& ec, void on_detect (LogicType& logic, error_code& ec,
ConstBuffers const& buffers) ConstBuffers const& buffers)
{ {
m_owner->on_detect_ssl (ec, buffers, logic.success ()); m_owner->on_detect_ssl (logic, ec, buffers);
} }
void on_async_detect (LogicType& logic, error_code const& ec, void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, ErrorCall const& origHandler) ConstBuffers const& buffers, ErrorCall const& origHandler)
{ {
m_owner->on_async_detect_ssl (ec, buffers, origHandler, logic.success ()); m_owner->on_async_detect_ssl (logic, ec, buffers, origHandler);
} }
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE #if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
void on_async_detect (LogicType& logic, error_code const& ec, void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, TransferCall const& origHandler) ConstBuffers const& buffers, TransferCall const& origHandler)
{ {
m_owner->on_async_detect_ssl (ec, buffers, origHandler, logic.success ()); m_owner->on_async_detect_ssl (logic, ec, buffers, origHandler);
}
#endif
private:
MultiSocketType <StreamSocket>* m_owner;
};
//--------------------------------------------------------------------------
//
// PROXY callback
//
class PROXYDetectCallback
: public HandshakeDetectStream <HandshakeDetectLogicPROXY>::Callback
{
public:
typedef HandshakeDetectLogicPROXY LogicType;
explicit PROXYDetectCallback (MultiSocketType <StreamSocket>* owner)
: m_owner (owner)
{
}
void on_detect (LogicType& logic, error_code& ec, ConstBuffers const& buffers)
{
m_owner->on_detect_proxy (logic, ec, buffers);
}
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, ErrorCall const& origHandler)
{
m_owner->on_async_detect_proxy (logic, ec, buffers, origHandler);
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, TransferCall const& origHandler)
{
m_owner->on_async_detect_proxy (logic, ec, buffers, origHandler);
} }
#endif #endif