Generic HandshakeDetectStream and HandshakeDetectLogic

This commit is contained in:
Vinnie Falco
2013-08-16 08:01:59 -07:00
parent 0feccfd7b0
commit bb941354ce
18 changed files with 787 additions and 945 deletions

View File

@@ -7,296 +7,6 @@
#ifndef RIPPLE_MULTISOCKETTYPE_H_INCLUDED
#define RIPPLE_MULTISOCKETTYPE_H_INCLUDED
//------------------------------------------------------------------------------
// should really start using this
namespace detail
{
}
//------------------------------------------------------------------------------
class ssl_detector_base
{
protected:
typedef boost::system::error_code error_code;
public:
/** Called when the state is known.
This could be called from any thread, most likely an io_service
thread but don't rely on that.
The Callback must be allocated via operator new.
*/
struct Callback
{
virtual ~Callback () { }
/** Called for synchronous ssl detection.
Note that the storage for the buffers passed to the
callback is owned by the detector class and becomes
invalid when the detector class is destroyed, which is
a common thing to do from inside your callback.
@param ec A modifiable error code that becomes the return
value of handshake.
@param buffers The bytes that were read in.
@param is_ssl True if the sequence is an ssl handshake.
*/
virtual void on_detect (error_code& ec,
ConstBuffers const& buffers, bool is_ssl) = 0;
virtual void on_async_detect (error_code const& ec,
ConstBuffers const& buffers, ErrorCall const& origHandler,
bool is_ssl) = 0;
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
virtual void on_async_detect (error_code const& ec,
ConstBuffers const& buffers, TransferCall const& origHandler,
bool is_ssl) = 0;
#endif
};
};
/** A stream that can detect an SSL handshake.
*/
template <typename Stream>
class ssl_detector
: public ssl_detector_base
, public boost::asio::ssl::stream_base
, public boost::asio::socket_base
{
private:
typedef ssl_detector <Stream> this_type;
typedef boost::asio::streambuf buffer_type;
typedef typename boost::remove_reference <Stream>::type stream_type;
public:
/** This takes ownership of the callback.
The callback must be allocated with operator new.
*/
template <typename Arg>
ssl_detector (Callback* callback, Arg& arg)
: m_callback (callback)
, m_next_layer (arg)
, m_stream (m_next_layer)
{
}
// basic_io_object
boost::asio::io_service& get_io_service ()
{
return m_next_layer.get_io_service ();
}
// basic_socket
typedef typename stream_type::protocol_type protocol_type;
typedef typename stream_type::lowest_layer_type lowest_layer_type;
lowest_layer_type& lowest_layer ()
{
return m_next_layer.lowest_layer ();
}
lowest_layer_type const& lowest_layer () const
{
return m_next_layer.lowest_layer ();
}
// ssl::stream
error_code handshake (handshake_type type, error_code& ec)
{
return do_handshake (type, ec, ConstBuffers ());
}
template <typename HandshakeHandler>
BOOST_ASIO_INITFN_RESULT_TYPE(HandshakeHandler, void (error_code))
async_handshake (handshake_type type, BOOST_ASIO_MOVE_ARG(HandshakeHandler) handler)
{
#if BEAST_ASIO_HAS_FUTURE_RETURNS
boost::asio::detail::async_result_init<
ErrorCall, void (error_code)> init(
BOOST_ASIO_MOVE_CAST(HandshakeHandler)(handler));
// init.handler is copied
m_origHandler = ErrorCall (HandshakeHandler(init.handler));
bassert (m_origBufferedHandler.isNull ());
async_do_handshake (type, ConstBuffers ());
return init.result.get();
#else
m_origHandler = ErrorCall (handler);
bassert (m_origBufferedHandler.isNull ());
async_do_handshake (type, ConstBuffers ());
#endif
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
template <typename ConstBufferSequence>
error_code handshake (handshake_type type,
const ConstBufferSequence& buffers, error_code& ec)
{
return do_handshake (type, ec, ConstBuffers (buffers));;
}
template <typename ConstBufferSequence, typename BufferedHandshakeHandler>
BOOST_ASIO_INITFN_RESULT_TYPE(BufferedHandshakeHandler, void (error_code, std::size_t))
async_handshake(handshake_type type, const ConstBufferSequence& buffers,
BOOST_ASIO_MOVE_ARG(BufferedHandshakeHandler) handler)
{
#if BEAST_ASIO_HAS_FUTURE_RETURNS
boost::asio::detail::async_result_init<
BufferedHandshakeHandler, void (error_code, std::size_t)> init(
BOOST_ASIO_MOVE_CAST(BufferedHandshakeHandler)(handler));
// init.handler is copied
m_origBufferedHandler = TransferCall (BufferedHandshakeHandler(init.handler));
bassert (m_origHandler.isNull ());
async_do_handshake (type, ConstBuffers (buffers));
return init.result.get();
#else
m_origBufferedHandler = TransferCall (handler);
bassert (m_origHandler.isNull ());
async_do_handshake (type, ConstBuffers (buffers));
#endif
}
#endif
//--------------------------------------------------------------------------
error_code do_handshake (handshake_type, error_code& ec, ConstBuffers const& buffers)
{
ec = error_code ();
// Transfer caller data to our buffer.
m_buffer.commit (boost::asio::buffer_copy (m_buffer.prepare (
boost::asio::buffer_size (buffers)), buffers));
do
{
std::size_t const available = m_buffer.size ();
std::size_t const needed = m_detector.max_needed ();
if (available < needed)
{
buffer_type::mutable_buffers_type buffers (
m_buffer.prepare (needed - available));
m_buffer.commit (m_next_layer.read_some (buffers, ec));
}
if (! ec)
{
m_detector.analyze (m_buffer.data ());
if (m_detector.finished ())
{
m_callback->on_detect (ec,
ConstBuffers (m_buffer.data ()),
m_detector.success ());
break;
}
// If this fails it means we will never finish
check_postcondition (available < needed);
}
}
while (! ec);
return ec;
}
//--------------------------------------------------------------------------
void async_do_handshake (handshake_type type, ConstBuffers const& buffers)
{
// Transfer caller data to our buffer.
std::size_t const bytes_transferred (boost::asio::buffer_copy (
m_buffer.prepare (boost::asio::buffer_size (buffers)), buffers));
// bootstrap the asynchronous loop
on_async_read_some (error_code (), bytes_transferred);
}
// asynchronous version of the synchronous loop found in handshake ()
//
void on_async_read_some (error_code const& ec, std::size_t bytes_transferred)
{
if (! ec)
{
m_buffer.commit (bytes_transferred);
std::size_t const available = m_buffer.size ();
std::size_t const needed = m_detector.max_needed ();
m_detector.analyze (m_buffer.data ());
if (m_detector.finished ())
{
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
if (! m_origBufferedHandler.isNull ())
{
bassert (m_origHandler.isNull ());
// continuation?
m_callback->on_async_detect (ec, ConstBuffers (m_buffer.data ()),
m_origBufferedHandler, m_detector.success ());
return;
}
#endif
bassert (! m_origHandler.isNull ())
// continuation?
m_callback->on_async_detect (ec,
ConstBuffers (m_buffer.data ()), m_origHandler,
m_detector.success ());
return;
}
// If this fails it means we will never finish
check_postcondition (available < needed);
buffer_type::mutable_buffers_type buffers (m_buffer.prepare (
needed - available));
// need a continuation hook here?
m_next_layer.async_read_some (buffers, boost::bind (
&this_type::on_async_read_some, this, boost::asio::placeholders::error,
boost::asio::placeholders::bytes_transferred));
return;
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
if (! m_origBufferedHandler.isNull ())
{
bassert (m_origHandler.isNull ());
// continuation?
m_callback->on_async_detect (ec, ConstBuffers (m_buffer.data ()),
m_origBufferedHandler, false);
return;
}
#endif
bassert (! m_origHandler.isNull ())
// continuation?
m_callback->on_async_detect (ec,
ConstBuffers (m_buffer.data ()), m_origHandler,
false);
}
private:
ScopedPointer <Callback> m_callback;
Stream m_next_layer;
buffer_type m_buffer;
boost::asio::buffered_read_stream <stream_type&> m_stream;
HandshakeDetectorType <SSL3> m_detector;
ErrorCall m_origHandler;
TransferCall m_origBufferedHandler;
};
//------------------------------------------------------------------------------
/** Template for producing instances of MultiSocket
*/
template <class StreamSocket>
@@ -325,7 +35,7 @@ public:
// These types are used internaly
//
// The type of stream that the ssl_detector uses
// The type of stream that the HandshakeDetectStream uses
typedef next_layer_type& ssl_detector_stream_type;
template <class Arg>
@@ -603,8 +313,8 @@ protected:
Socket* new_detect_ssl_stream ()
{
ssl_detector_callback* callback = new ssl_detector_callback (this);
typedef ssl_detector <ssl_detector_stream_type> this_type;
SSL3DetectCallback* callback = new SSL3DetectCallback (this);
typedef HandshakeDetectStreamType <ssl_detector_stream_type, HandshakeDetectLogicSSL3> this_type;
return new SocketWrapper <this_type> (callback, m_next_layer);
}
@@ -848,35 +558,36 @@ protected:
//--------------------------------------------------------------------------
//
// ssl_detector calback
// HandshakeDetectStream calback
//
class ssl_detector_callback : public ssl_detector_base::Callback
class SSL3DetectCallback
: public HandshakeDetectStream <HandshakeDetectLogicSSL3>::Callback
{
public:
explicit ssl_detector_callback (MultiSocketType <StreamSocket>* owner)
typedef HandshakeDetectLogicSSL3 LogicType;
explicit SSL3DetectCallback (MultiSocketType <StreamSocket>* owner)
: m_owner (owner)
{
}
void on_detect (error_code& ec,
ConstBuffers const& buffers, bool is_ssl)
void on_detect (LogicType& logic, error_code& ec,
ConstBuffers const& buffers)
{
m_owner->on_detect_ssl (ec, buffers, is_ssl);
m_owner->on_detect_ssl (ec, buffers, logic.success ());
}
void on_async_detect (error_code const& ec,
ConstBuffers const& buffers, ErrorCall const& origHandler,
bool is_ssl)
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, ErrorCall const& origHandler)
{
m_owner->on_async_detect_ssl (ec, buffers, origHandler, is_ssl);
m_owner->on_async_detect_ssl (ec, buffers, origHandler, logic.success ());
}
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
void on_async_detect (error_code const& ec,
ConstBuffers const& buffers, TransferCall const& origHandler,
bool is_ssl)
void on_async_detect (LogicType& logic, error_code const& ec,
ConstBuffers const& buffers, TransferCall const& origHandler)
{
m_owner->on_async_detect_ssl (ec, buffers, origHandler, is_ssl);
m_owner->on_async_detect_ssl (ec, buffers, origHandler, logic.success ());
}
#endif