Refactor handshake parsing and logic classes

This commit is contained in:
Vinnie Falco
2013-08-16 13:42:20 -07:00
parent 6108aca055
commit 4615297b00
15 changed files with 747 additions and 607 deletions

View File

@@ -77,10 +77,11 @@
<ClInclude Include="..\..\modules\beast_asio\basics\beast_TransferCall.h" />
<ClInclude Include="..\..\modules\beast_asio\beast_asio.h" />
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_HandshakeDetectLogic.h" />
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_HandshakeDetectLogicPROXY.h" />
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_HandshakeDetectLogicSSL2.h" />
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_HandshakeDetectLogicSSL3.h" />
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_HandshakeDetectStream.h" />
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_ProxyHandshake.h" />
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_InputParser.h" />
<ClInclude Include="..\..\modules\beast_asio\sockets\beast_Socket.h" />
<ClInclude Include="..\..\modules\beast_asio\sockets\beast_SocketBase.h" />
<ClInclude Include="..\..\modules\beast_asio\sockets\beast_SocketWrapper.h" />
@@ -309,7 +310,7 @@
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|x64'">true</ExcludedFromBuild>
</ClCompile>
<ClCompile Include="..\..\modules\beast_asio\beast_asio.cpp" />
<ClCompile Include="..\..\modules\beast_asio\handshake\beast_ProxyHandshake.cpp">
<ClCompile Include="..\..\modules\beast_asio\handshake\beast_HandshakeDetectLogicPROXY.cpp">
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">true</ExcludedFromBuild>
<ExcludedFromBuild Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">true</ExcludedFromBuild>

View File

@@ -836,9 +836,6 @@
<ClInclude Include="..\..\modules\beast_asio\streams\beast_PrefilledReadStream.h">
<Filter>beast_asio\streams</Filter>
</ClInclude>
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_ProxyHandshake.h">
<Filter>beast_asio\handshake</Filter>
</ClInclude>
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_HandshakeDetectLogic.h">
<Filter>beast_asio\handshake</Filter>
</ClInclude>
@@ -854,6 +851,12 @@
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_HandshakeDetectStream.h">
<Filter>beast_asio\handshake</Filter>
</ClInclude>
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_HandshakeDetectLogicPROXY.h">
<Filter>beast_asio\handshake</Filter>
</ClInclude>
<ClInclude Include="..\..\modules\beast_asio\handshake\beast_InputParser.h">
<Filter>beast_asio\handshake</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="..\..\modules\beast_core\beast_core.cpp">
@@ -1312,7 +1315,7 @@
<ClCompile Include="..\..\modules\beast_asio\basics\beast_PeerRole.cpp">
<Filter>beast_asio\basics</Filter>
</ClCompile>
<ClCompile Include="..\..\modules\beast_asio\handshake\beast_ProxyHandshake.cpp">
<ClCompile Include="..\..\modules\beast_asio\handshake\beast_HandshakeDetectLogicPROXY.cpp">
<Filter>beast_asio\handshake</Filter>
</ClCompile>
</ItemGroup>

View File

@@ -24,68 +24,88 @@
This provides a convenient interface for doing a bytewise
verification/reject test on a handshake protocol.
*/
template <int Bytes>
struct FixedInputBuffer
/** @{ */
class FixedInputBuffer
{
template <typename ConstBufferSequence>
explicit FixedInputBuffer (ConstBufferSequence const& buffer)
: m_buffer (boost::asio::buffer (m_storage))
, m_size (boost::asio::buffer_copy (m_buffer, buffer))
, m_data (boost::asio::buffer_cast <uint8 const*> (m_buffer))
protected:
struct CtorParams
{
CtorParams (uint8 const* begin_, std::size_t bytes_)
: begin (begin_)
, bytes (bytes_)
{
}
uint8 const* begin;
std::size_t bytes;
};
FixedInputBuffer (CtorParams const& params)
: m_begin (params.begin)
, m_iter (m_begin)
, m_end (m_begin + params.bytes)
{
}
uint8 operator[] (std::size_t index) const noexcept
public:
FixedInputBuffer (FixedInputBuffer const& other)
: m_begin (other.m_begin)
, m_iter (other.m_iter)
, m_end (other.m_end)
{
bassert (index >= 0 && index < m_size);
return m_data [index];
}
bool peek (std::size_t bytes) const noexcept
FixedInputBuffer& operator= (FixedInputBuffer const& other)
{
if (m_size >= bytes)
return true;
return false;
m_begin = other.m_begin;
m_iter = other.m_iter;
m_end = other.m_end;
return *this;
}
// Returns the number of bytes consumed
std::size_t used () const noexcept
{
return m_iter - m_begin;
}
// Returns the size of what's remaining
std::size_t size () const noexcept
{
return m_end - m_iter;
}
void const* peek (std::size_t bytes)
{
return peek_impl (bytes, nullptr);
}
template <typename T>
bool peek (T* t) noexcept
bool peek (T* t) const noexcept
{
std::size_t const bytes = sizeof (T);
if (m_size >= bytes)
{
std::copy (m_data, m_data + bytes, t);
return true;
}
return false;
return peek_impl (sizeof (T), t) != nullptr;
}
bool consume (std::size_t bytes) noexcept
{
if (m_size >= bytes)
{
m_data += bytes;
m_size -= bytes;
return true;
}
return false;
return read_impl (bytes, nullptr) != nullptr;
}
bool read (std::size_t bytes) noexcept
{
return read_impl (bytes, nullptr) != nullptr;
}
template <typename T>
bool read (T* t) noexcept
{
std::size_t const bytes = sizeof (T);
if (m_size >= bytes)
{
//this causes a stack corruption.
//std::copy (m_data, m_data + bytes, t);
return read_impl (sizeof (T), t) != nullptr;
}
memcpy (t, m_data, bytes);
m_data += bytes;
m_size -= bytes;
return true;
}
return false;
uint8 operator[] (std::size_t index) const noexcept
{
bassert (index >= 0 && index < size ());
return m_iter [index];
}
// Reads an integraltype in network byte order
@@ -97,16 +117,77 @@ struct FixedInputBuffer
//static_bassert (std::is_integral <IntegerType>::value);
IntegerType networkValue;
if (! read (&networkValue))
return;
return false;
*value = fromNetworkByteOrder (networkValue);
return true;
}
protected:
void const* peek_impl (std::size_t bytes, void* buffer) const noexcept
{
if (size () >= bytes)
{
if (buffer != nullptr)
memcpy (buffer, m_iter, bytes);
return m_iter;
}
return nullptr;
}
void const* read_impl (std::size_t bytes, void* buffer) noexcept
{
if (size () >= bytes)
{
if (buffer != nullptr)
memcpy (buffer, m_iter, bytes);
void const* data = m_iter;
m_iter += bytes;
return data;
}
return nullptr;
}
private:
uint8 const* m_begin;
uint8 const* m_iter;
uint8 const* m_end;
};
//------------------------------------------------------------------------------
template <int Bytes>
class FixedInputBufferSize : public FixedInputBuffer
{
protected:
struct SizedCtorParams
{
template <typename ConstBufferSequence, typename Storage>
SizedCtorParams (ConstBufferSequence const& buffers, Storage& storage)
{
MutableBuffer buffer (boost::asio::buffer (storage));
data = boost::asio::buffer_cast <uint8 const*> (buffer);
bytes = boost::asio::buffer_copy (buffer, buffers);
}
operator CtorParams () const noexcept
{
return CtorParams (data, bytes);
}
uint8 const* data;
std::size_t bytes;
};
public:
template <typename ConstBufferSequence>
explicit FixedInputBufferSize (ConstBufferSequence const& buffers)
: FixedInputBuffer (SizedCtorParams (buffers, m_storage))
{
}
private:
boost::array <uint8, Bytes> m_storage;
MutableBuffer m_buffer;
std::size_t m_size;
uint8 const* m_data;
};
#endif

View File

@@ -30,7 +30,7 @@ namespace beast
#include "sockets/beast_Socket.cpp"
#include "sockets/beast_SslContext.cpp"
#include "handshake/beast_ProxyHandshake.cpp"
#include "handshake/beast_HandshakeDetectLogicPROXY.cpp"
#include "tests/beast_TestPeerBasics.cpp"
#include "tests/beast_TestPeerLogic.cpp"

View File

@@ -60,8 +60,9 @@ namespace beast
#include "sockets/beast_SocketWrapper.h"
#include "sockets/beast_SslContext.h"
#include "handshake/beast_ProxyHandshake.h"
#include "handshake/beast_HandshakeDetectLogic.h"
#include "handshake/beast_InputParser.h"
#include "handshake/beast_HandshakeDetectLogic.h"
#include "handshake/beast_HandshakeDetectLogicPROXY.h"
#include "handshake/beast_HandshakeDetectLogicSSL2.h"
#include "handshake/beast_HandshakeDetectLogicSSL3.h"
#include "handshake/beast_HandshakeDetectStream.h"

View File

@@ -42,6 +42,13 @@ public:
*/
virtual std::size_t max_needed () = 0;
/** How many bytes the handshake consumes.
If the detector processes the entire handshake this will
be non zero. The SSL detector would return 0, since we
want all the existing bytes to be passed on.
*/
virtual std::size_t bytes_consumed () = 0;
/** Return true if we have enough data to form a conclusion.
*/
bool finished () const noexcept
@@ -94,19 +101,24 @@ public:
return m_logic;
}
std::size_t max_needed () noexcept
std::size_t max_needed ()
{
return m_logic.max_needed ();
}
bool finished () noexcept
std::size_t bytes_consumed ()
{
return m_logic.bytes_consumed ();
}
bool finished ()
{
return m_logic.finished ();
}
/** If finished is true, this tells us if the handshake was detected.
*/
bool success () noexcept
bool success ()
{
return m_logic.success ();
}

View File

@@ -0,0 +1,20 @@
//------------------------------------------------------------------------------
/*
This file is part of Beast: https://github.com/vinniefalco/Beast
Copyright 2013, Vinnie Falco <vinnie.falco@gmail.com>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL , DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================

View File

@@ -0,0 +1,159 @@
//------------------------------------------------------------------------------
/*
This file is part of Beast: https://github.com/vinniefalco/Beast
Copyright 2013, Vinnie Falco <vinnie.falco@gmail.com>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL , DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#ifndef BEAST_HANDSHAKEDETECTLOGICPROXY_H_INCLUDED
#define BEAST_HANDSHAKEDETECTLOGICPROXY_H_INCLUDED
/** Handshake detector for the PROXY protcol
http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt
*/
class HandshakeDetectLogicPROXY : public HandshakeDetectLogic
{
public:
typedef int arg_type;
enum
{
// This is for version 1. The largest number of bytes
// that we could possibly need to parse a valid handshake.
// We will reject it much sooner if there's an illegal value.
//
maxBytesNeeded = 107 // including CRLF, no null term
};
struct ProxyInfo
{
typedef InputParser::IPv4Address IPv4Address;
String protocol; // "TCP4", "TCP6", "UNKNOWN"
IPv4Address sourceAddress;
IPv4Address destAddress;
uint16 sourcePort;
uint16 destPort;
};
explicit HandshakeDetectLogicPROXY (arg_type const&)
: m_consumed (0)
{
}
ProxyInfo const& getInfo () const noexcept
{
return m_info;
}
std::size_t max_needed ()
{
return maxBytesNeeded;
}
std::size_t bytes_consumed ()
{
return m_consumed;
}
template <typename ConstBufferSequence>
void analyze (ConstBufferSequence const& buffer)
{
FixedInputBufferSize <maxBytesNeeded> in (buffer);
InputParser::State state;
analyze_input (in, state);
if (state.passed ())
{
m_consumed = in.used ();
conclude (true);
}
else if (state.failed ())
{
conclude (false);
}
}
void analyze_input (FixedInputBuffer& in, InputParser::State& state)
{
using namespace InputParser;
if (! match (in, "PROXY "), state)
return;
if (match (in, "TCP4 "))
{
m_info.protocol = "TCP4";
if (! read (in, m_info.sourceAddress, state))
return;
if (! match (in, " ", state))
return;
if (! read (in, m_info.destAddress, state))
return;
if (! match (in, " ", state))
return;
UInt16Str sourcePort;
if (! read (in, sourcePort, state))
return;
m_info.sourcePort = sourcePort.value;
if (! match (in, " ", state))
return;
UInt16Str destPort;
if (! read (in, destPort, state))
return;
m_info.destPort = destPort.value;
if (! match (in, "\r\n", state))
return;
state = State::pass;
return;
}
else if (match (in, "TCP6 "))
{
m_info.protocol = "TCP6";
state = State::fail;
return;
}
else if (match (in, "UNKNOWN "))
{
m_info.protocol = "UNKNOWN";
state = State::fail;
return;
}
state = State::fail;
}
private:
std::size_t m_consumed;
ProxyInfo m_info;
};
#endif

View File

@@ -53,10 +53,15 @@ public:
return bytesNeeded;
}
std::size_t bytes_consumed ()
{
return 0;
}
template <typename ConstBufferSequence>
void analyze (ConstBufferSequence const& buffer)
{
FixedInputBuffer <bytesNeeded> in (buffer);
FixedInputBufferSize <bytesNeeded> in (buffer);
{
uint8 byte;

View File

@@ -45,11 +45,16 @@ public:
return bytesNeeded;
}
std::size_t bytes_consumed ()
{
return 0;
}
template <typename ConstBufferSequence>
void analyze (ConstBufferSequence const& buffer)
{
uint16 version;
FixedInputBuffer <bytesNeeded> in (buffer);
FixedInputBufferSize <bytesNeeded> in (buffer);
uint8 msg_type;
if (! in.read (&msg_type))

View File

@@ -97,6 +97,18 @@ public:
{
}
// This puts bytes that you already have into the detector buffer
// Any leftovers will be given to the callback.
// A copy of the data is made.
//
template <typename ConstBufferSequence>
void fill (ConstBufferSequence const& buffers)
{
m_buffer.commit (boost::asio::buffer_copy (
m_buffer.prepare (boost::asio::buffer_size (buffers)),
buffers));
}
// basic_io_object
boost::asio::io_service& get_io_service ()
@@ -203,6 +215,10 @@ public:
if (m_logic.finished ())
{
// consume what we used (for SSL its 0)
std::size_t const consumed = m_logic.bytes_consumed ();
bassert (consumed <= m_buffer.size ());
m_buffer.consume (consumed);
m_callback->on_detect (m_logic.get (), ec,
ConstBuffers (m_buffer.data ()));
break;
@@ -240,10 +256,16 @@ public:
std::size_t const available = m_buffer.size ();
std::size_t const needed = m_logic.max_needed ();
m_logic.analyze (m_buffer.data ());
if (bytes_transferred > 0)
m_logic.analyze (m_buffer.data ());
if (m_logic.finished ())
{
// consume what we used (for SSL its 0)
std::size_t const consumed = m_logic.bytes_consumed ();
bassert (consumed <= m_buffer.size ());
m_buffer.consume (consumed);
#if BEAST_ASIO_HAS_BUFFEREDHANDSHAKE
if (! m_origBufferedHandler.isNull ())
{

View File

@@ -0,0 +1,379 @@
//------------------------------------------------------------------------------
/*
This file is part of Beast: https://github.com/vinniefalco/Beast
Copyright 2013, Vinnie Falco <vinnie.falco@gmail.com>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL , DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#ifndef BEAST_INPUTPARSER_H_INCLUDED
#define BEAST_INPUTPARSER_H_INCLUDED
namespace InputParser
{
/** Tri-valued parsing state.
This is convertible to bool which means continue.
Or you can use stop() to decide if you should return.
After a stop you can use failed () to determine if parsing failed.
*/
struct State : SafeBool <State>
{
enum State_t
{
pass, // passed the parse
fail, // failed the parse
more // didn't fail but need more bytes
};
State () : m_state (more) { }
State (State_t state) : m_state (state) { }
/** Implicit construction from bool.
If condition is true then the parse passes, else need more.
*/
State (bool condition) : m_state (condition ? pass : more) { }
State& operator= (State_t state) { m_state = state; return *this; }
bool eof () const noexcept { return m_state == more; }
bool stop () const noexcept { return m_state != pass; }
bool passed () const noexcept { return m_state == pass; }
bool failed () const noexcept { return m_state == fail; }
bool asBoolean () const noexcept { return m_state == pass; } // for SafeBool<>
private:
State_t m_state;
};
//------------------------------------------------------------------------------
// Shortcut to save typing.
typedef FixedInputBuffer& Input;
/** Specializations implement the get() function. */
template <class T>
struct Get;
/** Specializations implement the match() function.
Default implementation of match tries to read it into a local.
*/
template <class T>
struct Match
{
static State func (Input in, T other)
{
T t;
State state = Get <T>::func (in, t);
if (state.passed ())
{
if (t == other)
return State::passed;
return State::failed;
}
return state;
}
};
/** Specializations implement the peek() function.
Default implementation of peek reads and rewinds.
*/
template <class T>
struct Peek
{
static State func (Input in, T& t)
{
Input dup (in);
return Get <T>::func (dup, t);
}
};
//------------------------------------------------------------------------------
//
// Free Functions
//
//------------------------------------------------------------------------------
// match a block of data in memory
//
static State match_buffer (Input in, void const* buffer, std::size_t bytes)
{
bassert (bytes > 0);
if (in.size () <= 0)
return State::more;
std::size_t const have = std::min (in.size (), bytes);
void const* data = in.peek (have);
bassert (data != nullptr);
int const compare = memcmp (data, buffer, have);
if (compare != 0)
return State::fail;
in.consume (have);
return have == bytes;
}
//------------------------------------------------------------------------------
//
// match
//
// Returns the state
template <class T>
State match (Input in, T t)
{
return Match <T>::func (in, t);
}
// Stores the state in the argument and returns true if its a pass
template <class T>
bool match (Input in, T t, State& state)
{
return (state = match (in, t)).passed ();
}
//------------------------------------------------------------------------------
//
// peek
//
// Returns the state
template <class T>
State peek (Input in, T& t)
{
return Peek <T>::func (in, t);
}
// Stores the state in the argument and returns true if its a pass
template <class T>
bool peek (Input in, T& t, State& state)
{
return (state = peek (in, t)).passed ();
}
//------------------------------------------------------------------------------
//
// read
//
// Returns the state
template <class T>
State read (Input in, T& t)
{
return Get <T>::func (in, t);
}
// Stores the state in the argument and returns true if its a pass
template <class T>
bool read (Input in, T& t, State& state)
{
return (state = read (in, t)).passed ();
}
//------------------------------------------------------------------------------
//
// Specializations for basic types
//
template <>
struct Match <char const*>
{
static State func (Input in, char const* text)
{
return InputParser::match_buffer (in, text, strlen (text));
}
};
//------------------------------------------------------------------------------
//
// Special types and their specializations
//
struct Digit
{
int value;
};
template <>
struct Get <Digit>
{
static State func (Input in, Digit& t)
{
char c;
if (! in.peek (&c))
return State::more;
if (! std::isdigit (c))
return State::fail;
in.consume (1);
t.value = c - '0';
return State::pass;
}
};
//------------------------------------------------------------------------------
// An unsigned 32 bit number expressed as a string
struct UInt32Str
{
uint32 value;
};
template <>
struct Get <UInt32Str>
{
static State func (Input in, UInt32Str& t)
{
State state;
uint32 value (0);
Digit digit;
// have to have at least one digit
if (! read (in, digit, state))
return state;
value = digit.value;
for (;;)
{
state = peek (in, digit);
if (state.failed ())
{
t.value = value;
return State::pass;
}
else if (state.eof ())
{
t.value = value;
return state;
}
// can't have a digit following a zero
if (value == 0)
return State::fail;
uint32 newValue = (value * 10) + digit.value;
// overflow
if (newValue < value)
return State::fail;
value = newValue;
}
return State::fail;
}
};
//------------------------------------------------------------------------------
// An unsigned 16 bit number expressed as a string
struct UInt16Str
{
uint16 value;
};
template <>
struct Get <UInt16Str>
{
static State func (Input in, UInt16Str& t)
{
UInt32Str v;
State state = read (in, v);
if (state.passed ())
{
if (v.value <= 65535)
{
t.value = uint16(v.value);
return State::pass;
}
return State::fail;
}
return state;
}
};
//------------------------------------------------------------------------------
// An unsigned 8 bit number expressed as a string
struct UInt8Str
{
uint8 value;
};
template <>
struct Get <UInt8Str>
{
static State func (Input in, UInt8Str& t)
{
UInt32Str v;
State state = read (in, v);
if (state.passed ())
{
if (v.value <= 255)
{
t.value = uint8(v.value);
return State::pass;
}
return State::fail;
}
return state;
}
};
//------------------------------------------------------------------------------
// An dotted IPv4 address
struct IPv4Address
{
uint8 value [4];
};
template <>
struct Get <IPv4Address>
{
static State func (Input in, IPv4Address& t)
{
State state;
UInt8Str digit [4];
if (! read (in, digit [0], state))
return state;
if (! match (in, ".", state))
return state;
if (! read (in, digit [1], state))
return state;
if (! match (in, ".", state))
return state;
if (! read (in, digit [2], state))
return state;
if (! match (in, ".", state))
return state;
if (! read (in, digit [3], state))
return state;
t.value [0] = digit [0].value;
t.value [1] = digit [1].value;
t.value [2] = digit [2].value;
t.value [3] = digit [3].value;
return State::pass;
}
};
//------------------------------------------------------------------------------
}
#endif

View File

@@ -1,384 +0,0 @@
//------------------------------------------------------------------------------
/*
This file is part of Beast: https://github.com/vinniefalco/Beast
Copyright 2013, Vinnie Falco <vinnie.falco@gmail.com>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL , DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
ProxyHandshake::ProxyHandshake (bool expectHandshake)
: m_status (expectHandshake ? statusHandshake : statusNone)
, m_gotCR (false)
{
m_buffer.preallocateBytes (maxVersion1Bytes);
}
ProxyHandshake::~ProxyHandshake ()
{
}
std::size_t ProxyHandshake::feed (void const* inputBuffer, size_t inputBytes)
{
std::size_t bytesConsumed = 0;
char const* p = static_cast <char const*> (inputBuffer);
if (m_status == statusHandshake)
{
if (! m_gotCR)
{
while (inputBytes > 0 && m_buffer.length () < maxVersion1Bytes - 1)
{
beast_wchar c = *p++;
++bytesConsumed;
--inputBytes;
m_buffer += c;
if (c == '\r')
{
m_gotCR = true;
break;
}
else if (c == '\n')
{
m_status = statusFailed;
}
}
if (m_buffer.length () > maxVersion1Bytes - 1)
{
m_status = statusFailed;
}
}
}
if (m_status == statusHandshake)
{
if (inputBytes > 0 && m_gotCR)
{
bassert (m_buffer.length () < maxVersion1Bytes);
char const lf ('\n');
if (*p == lf)
{
++bytesConsumed;
--inputBytes;
m_buffer += lf;
parseLine ();
}
else
{
m_status = statusFailed;
}
}
}
return bytesConsumed;
}
void ProxyHandshake::parseLine ()
{
Version1 p;
bool success = p.parse (m_buffer.getCharPointer (), m_buffer.length ());
if (success)
{
m_endpoints = p.endpoints;
m_status = statusOk;
}
else
{
m_status = statusFailed;
}
}
int ProxyHandshake::indexOfFirstNonNumber (String const& input)
{
bassert (input.length () > 0);
int i = 0;
for (; i < input.length (); ++i)
{
if (! CharacterFunctions::isDigit (input [i]))
break;
}
return i;
}
bool ProxyHandshake::chop (String const& what, String& input)
{
if (input.startsWith (what))
{
input = input.substring (what.length ());
return true;
}
return false;
}
bool ProxyHandshake::chopUInt (int* value, int limit, String& input)
{
if (input.length () <= 0)
return false;
String const s = input.substring (0, indexOfFirstNonNumber (input));
if (s.length () <= 0)
return false;
int const n = s.getIntValue ();
// Leading zeroes disallowed as per spec, to prevent confusion with octal
if (String (n) != s)
return false;
if (n < 0 || n > limit)
return false;
input = input.substring (s.length ());
*value = n;
return true;
}
//------------------------------------------------------------------------------
/*
steps:
Proxy protocol lets us filter attackers by learning the source ip and port
1. Determine if we should use the proxy on a connection
- Port just for proxy protocol connections
- Filter on source IPs
2. Read a line from the connection to get the proxy information
3. Parse the line (human or binary?)
4. Code Interface to retrieve proxy information (ip/port) on connection
*/
ProxyHandshake::Version1::Version1 ()
{
}
bool ProxyHandshake::IPv4::Addr::chop (String& input)
{
if (!ProxyHandshake::chopUInt (&a, 255, input))
return false;
if (!ProxyHandshake::chop (".", input))
return false;
if (!ProxyHandshake::chopUInt (&b, 255, input))
return false;
if (!ProxyHandshake::chop (".", input))
return false;
if (!ProxyHandshake::chopUInt (&c, 255, input))
return false;
if (!ProxyHandshake::chop (".", input))
return false;
if (!ProxyHandshake::chopUInt (&d, 255, input))
return false;
return true;
}
bool ProxyHandshake::Version1::parse (void const* headerData, size_t headerBytes)
{
String input (static_cast <CharPointer_UTF8::CharType const*> (headerData), headerBytes);
if (input.length () < 2)
return false;
if (! input.endsWith ("\r\n"))
return false;
input = input.dropLastCharacters (2);
if (! ProxyHandshake::chop ("PROXY ", input))
return false;
if (ProxyHandshake::chop ("UNKNOWN", input))
{
endpoints.proto = protoUnknown;
input = "";
}
else
{
if (ProxyHandshake::chop ("TCP4 ", input))
{
endpoints.proto = protoTcp4;
if (! endpoints.ipv4.sourceAddr.chop (input))
return false;
if (! ProxyHandshake::chop (" ", input))
return false;
if (! endpoints.ipv4.destAddr.chop (input))
return false;
if (! ProxyHandshake::chop (" ", input))
return false;
if (! ProxyHandshake::chopUInt (&endpoints.ipv4.sourcePort, 65535, input))
return false;
if (! ProxyHandshake::chop (" ", input))
return false;
if (! ProxyHandshake::chopUInt (&endpoints.ipv4.destPort, 65535, input))
return false;
}
else if (ProxyHandshake::chop ("TCP6 ", input))
{
endpoints.proto = protoTcp6;
//bassertfalse;
return false;
}
else
{
return false;
}
}
// Can't have anything extra between the last port number and the CRLF
if (input.length () > 0)
return false;
return true;
}
//------------------------------------------------------------------------------
class ProxyHandshakeTests : public UnitTest
{
public:
ProxyHandshakeTests () : UnitTest ("ProxyHandshake", "beast")
{
}
static std::string goodIpv4 ()
{
return "PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n"; // 56 chars
}
static std::string goodIpv6 ()
{
return "PROXY TCP6 fffffffffffffffffffffffffffffffffffffff.fffffffffffffffffffffffffffffffffffffff 65535 65535\r\n";
//1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123 4 (104 chars)
}
static std::string goodUnknown ()
{
return "PROXY UNKNOWN\r\n";
}
static std::string goodUnknownBig ()
{
return "PROXY UNKNOWN fffffffffffffffffffffffffffffffffffffff.fffffffffffffffffffffffffffffffffffffff 65535 65535\r\n";
//1234567890123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456 7 (107 chars)
}
void testHandshake (std::string const& s, bool shouldSucceed)
{
if (s.size () > 1)
{
ProxyHandshake h (true);
expect (h.getStatus () == ProxyHandshake::statusHandshake);
for (std::size_t i = 0; i < s.size () && h.getStatus () == ProxyHandshake::statusHandshake ; ++i)
{
std::size_t const bytesConsumed = h.feed (& s[i], 1);
if (i != s.size () - 1)
expect (h.getStatus () == ProxyHandshake::statusHandshake);
expect (bytesConsumed == 1);
}
if (shouldSucceed)
{
expect (h.getStatus () == ProxyHandshake::statusOk);
}
else
{
expect (h.getStatus () == ProxyHandshake::statusFailed);
}
}
else
{
bassertfalse;
}
}
void testVersion1String (std::string const& s, bool shouldSucceed)
{
ProxyHandshake::Version1 p;
if (shouldSucceed)
{
expect (p.parse (s.c_str (), s.size ()));
}
else
{
unexpected (p.parse (s.c_str (), s.size ()));
}
for (std::size_t i = 1; i < s.size () - 1; ++i)
{
String const partial = String (s).dropLastCharacters (i);
std::string ss (partial.toStdString ());
expect (! p.parse (ss.c_str (), ss.size ()));
}
testHandshake (s, shouldSucceed);
}
void testVersion1 ()
{
beginTestCase ("version1");
testVersion1String (goodIpv4 (), true);
testVersion1String (goodIpv6 (), false);
testVersion1String (goodUnknown (), true);
testVersion1String (goodUnknownBig (), true);
}
void runTest ()
{
testVersion1 ();
}
};
static ProxyHandshakeTests proxyHandshakeTests;

View File

@@ -1,164 +0,0 @@
//------------------------------------------------------------------------------
/*
This file is part of Beast: https://github.com/vinniefalco/Beast
Copyright 2013, Vinnie Falco <vinnie.falco@gmail.com>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
ANY SPECIAL , DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
*/
//==============================================================================
#ifndef BEAST_PROXYYHANDSHAKE_H_INCLUDED
#define BEAST_PROXYYHANDSHAKE_H_INCLUDED
/** PROXY protocol handshake state machine.
The PROXY Protocol:
http://haproxy.1wt.eu/download/1.5/doc/proxy-protocol.txt
*/
class ProxyHandshake
{
public:
/** Status of the handshake state machine. */
enum Status
{
statusNone, // No handshake expected
statusHandshake, // Handshake in progress
statusFailed, // Handshake failed
statusOk, // Handshake succeeded
};
enum Proto
{
protoTcp4,
protoTcp6,
protoUnknown
};
/** PROXY information for IPv4 families. */
struct IPv4
{
struct Addr
{
int a;
int b;
int c;
int d;
bool chop (String& input);
};
Addr sourceAddr;
Addr destAddr;
int sourcePort;
int destPort;
};
/** PROXY information for IPv6 families. */
struct IPv6
{
struct Addr
{
int a;
int b;
int c;
int d;
};
Addr sourceAddr;
Addr destAddr;
int sourcePort;
int destPort;
};
/** Fully decoded PROXY information. */
struct Endpoints
{
Endpoints ()
: proto (protoUnknown)
{
}
Proto proto;
IPv4 ipv4; // valid if proto == protoTcp4
IPv6 ipv6; // valid if proto == protoTcp6;
};
//--------------------------------------------------------------------------
/** Parser for PROXY version 1. */
struct Version1
{
enum
{
// Maximum input buffer size needed, including a null
// terminator, as per the PROXY protocol specification.
maxBufferBytes = 108
};
Endpoints endpoints;
Version1 ();
/** Parse the header.
@param rawHeader a pointer to the header data
@return `true` If it was parsed successfully.
*/
bool parse (void const* headerData, size_t headerBytes);
};
//--------------------------------------------------------------------------
/** Create the handshake state.
If a handshake is expected, then it is required.
@param wantHandshake `false` to skip handshaking.
*/
explicit ProxyHandshake (bool expectHandshake = false);
~ProxyHandshake ();
inline Status getStatus () const noexcept
{
return m_status;
}
inline Endpoints const& getEndpoints () const noexcept
{
return m_endpoints;
};
/** Feed the handshaking state engine.
@return The number of bytes consumed in the input buffer.
*/
std::size_t feed (void const* inputBuffer, std::size_t inputBytes);
// Utility functions used by parsers
static int indexOfFirstNonNumber (String const& input);
static bool chop (String const& what, String& input);
static bool chopUInt (int* value, int limit, String& input);
private:
void parseLine ();
private:
enum
{
maxVersion1Bytes = 107 // including crlf, not including null term
};
Status m_status;
String m_buffer;
bool m_gotCR;
Endpoints m_endpoints;
};
#endif

View File

@@ -24,7 +24,7 @@ TestPeerLogicProxyClient::TestPeerLogicProxyClient (Socket& socket)
void TestPeerLogicProxyClient::on_pre_handshake ()
{
ProxyHandshake h;
//ProxyHandshakeParser h;
static std::string line (
"PROXY TCP4 255.255.255.255 255.255.255.255 65535 65535\r\n"