This commit is contained in:
JoelKatz
2013-01-24 12:50:48 -08:00
parent a6d189e2da
commit dba8ae2db6
2 changed files with 39 additions and 50 deletions

View File

@@ -1,32 +0,0 @@
#include "AutoSocket.h"
#include <boost/bind.hpp>
void AutoSocket::handle_autodetect(const error_code& ec)
{
if (ec)
{
if (mCallback)
mCallback(ec);
return;
}
if ((mBuffer[0] < 127) && (mBuffer[0] > 31) &&
(mBuffer[1] < 127) && (mBuffer[1] > 31) &&
(mBuffer[2] < 127) && (mBuffer[2] > 31) &&
(mBuffer[3] < 127) && (mBuffer[3] > 31))
{ // non-SSL
mSecure = false;
if (mCallback)
mCallback(ec);
}
else
{ // ssl
mSecure = true;
SSLSocket().async_handshake(ssl_socket::server, mCallback);
mCallback = callback();
}
}
// vim:ts=4

View File

@@ -8,6 +8,9 @@
#include <boost/asio.hpp>
#include <boost/asio/ssl.hpp>
#include "Log.h"
extern LogPartition AutoSocketPartition;
// Socket wrapper that supports both SSL and non-SSL connections.
// Generally, handle it as you would an SSL connection.
// To force a non-SSL connection, just don't call async_handshake.
@@ -21,13 +24,14 @@ class AutoSocket
public:
typedef bassl::stream<basio::ip::tcp::socket> ssl_socket;
typedef ssl_socket::next_layer_type plain_socket;
typedef ssl_socket::lowest_layer_type lowest_layer_type;
typedef ssl_socket::handshake_type handshake_type;
typedef boost::system::error_code error_code;
typedef boost::function<void(error_code)> callback;
typedef boost::function<void(error_code)> callback;
protected:
ssl_socket mSocket;
bool mSecure;
callback mCallback;
std::vector<char> mBuffer;
@@ -37,21 +41,20 @@ public:
bool isSecure() { return mSecure; }
ssl_socket& SSLSocket() { return mSocket; }
plain_socket& PlainSocket() { return mSocket.next_layer(); }
void setSSLOnly() { mBuffer.clear(); }
void async_handshake(ssl_socket::handshake_type type, callback cbFunc)
lowest_layer_type& lowest_layer() { return mSocket.lowest_layer(); }
void async_handshake(handshake_type type, callback cbFunc)
{
mSecure = true;
if ((type == ssl_socket::client) || (mBuffer.empty()))
SSLSocket().async_handshake(type, cbFunc);
else
{
mCallback = cbFunc;
PlainSocket().async_receive(basio::buffer(mBuffer), basio::socket_base::message_peek,
boost::bind(&AutoSocket::handle_autodetect, this, basio::placeholders::error));
mSecure = true;
mSocket.async_handshake(type, cbFunc);
}
else
mSocket.next_layer().async_receive(basio::buffer(mBuffer), basio::socket_base::message_peek,
boost::bind(&AutoSocket::handle_autodetect, this, cbFunc, basio::placeholders::error));
}
template <typename StreamType> StreamType& getSocket()
@@ -65,19 +68,18 @@ public:
template <typename ShutdownHandler> void async_shutdown(ShutdownHandler handler)
{
if (isSecure())
SSLSocket().async_shutdown(handler);
mSocket.async_shutdown(handler);
else
{
PlainSocket().shutdown(plain_socket::shutdown_both);
if (handler)
mSocket.get_io_service().post(handler);
lowest_layer().shutdown(plain_socket::shutdown_both);
mSocket.get_io_service().post(boost::bind(handler, error_code()));
}
}
template <typename Seq, typename Handler> void async_read_some(const Seq& buffers, Handler handler)
{
if (isSecure())
SSLSocket().async_read_some(buffers, handler);
mSocket.async_read_some(buffers, handler);
else
PlainSocket().async_read_some(buffers, handler);
}
@@ -85,13 +87,32 @@ public:
template <typename Seq, typename Handler> void async_write_some(const Seq& buffers, Handler handler)
{
if (isSecure())
SSLSocket().async_write_some(buffers, handler);
mSocket.async_write_some(buffers, handler);
else
PlainSocket().async_write_some(buffers, handler);
}
protected:
void handle_autodetect(const error_code&);
void handle_autodetect(callback cbFunc, const error_code& ec)
{
if (ec)
{
Log(lsWARNING, AutoSocketPartition) << "Handle autodetect error: " << ec;
cbFunc(ec);
}
else if ((mBuffer[0] < 127) && (mBuffer[0] > 31) &&
(mBuffer[1] < 127) && (mBuffer[1] > 31) &&
(mBuffer[2] < 127) && (mBuffer[2] > 31) &&
(mBuffer[3] < 127) && (mBuffer[3] > 31))
{ // not ssl
cbFunc(ec);
}
else
{ // ssl
mSecure = true;
mSocket.async_handshake(ssl_socket::server, cbFunc);
}
}
};
#endif