diff --git a/src/cpp/ripple/AutoSocket.cpp b/src/cpp/ripple/AutoSocket.cpp new file mode 100644 index 000000000..dd553a4c5 --- /dev/null +++ b/src/cpp/ripple/AutoSocket.cpp @@ -0,0 +1,32 @@ + +#include "AutoSocket.h" + +#include + +void AutoSocket::handle_autodetect(const error_code& ec) +{ + if (ec) + { + if (mCallback) + mCallback(ec); + return; + } + + if ((mBuffer[0] < 127) && (mBuffer[0] > 31) && + (mBuffer[1] < 127) && (mBuffer[1] > 31) && + (mBuffer[2] < 127) && (mBuffer[2] > 31) && + (mBuffer[3] < 127) && (mBuffer[3] > 31)) + { // non-SSL + mSecure = false; + if (mCallback) + mCallback(ec); + } + else + { // ssl + mSecure = true; + SSLSocket().async_handshake(ssl_socket::server, mCallback); + mCallback = callback(); + } +} + +// vim:ts=4 diff --git a/src/cpp/ripple/AutoSocket.h b/src/cpp/ripple/AutoSocket.h new file mode 100644 index 000000000..d8529ea28 --- /dev/null +++ b/src/cpp/ripple/AutoSocket.h @@ -0,0 +1,99 @@ +#ifndef __AUTOSOCKET_H_ +#define __AUTOSOCKET_H_ + +#include + +#include +#include +#include +#include + +// Socket wrapper that supports both SSL and non-SSL connections. +// Generally, handle it as you would an SSL connection. +// To force a non-SSL connection, just don't call async_handshake. +// To force SSL only inbound, call setSSLOnly. + +namespace basio = boost::asio; +namespace bassl = basio::ssl; + +class AutoSocket +{ +public: + typedef bassl::stream ssl_socket; + typedef ssl_socket::next_layer_type plain_socket; + typedef boost::system::error_code error_code; + typedef boost::function callback; + +protected: + ssl_socket mSocket; + bool mSecure; + callback mCallback; + + std::vector mBuffer; + +public: + AutoSocket(basio::io_service& s, bassl::context& c) : mSocket(s, c), mSecure(false), mBuffer(4) { ; } + + bool isSecure() { return mSecure; } + ssl_socket& SSLSocket() { return mSocket; } + plain_socket& PlainSocket() { return mSocket.next_layer(); } + + void setSSLOnly() { mBuffer.clear(); } + + void async_handshake(ssl_socket::handshake_type type, callback cbFunc) + { + mSecure = true; + if ((type == ssl_socket::client) || (mBuffer.empty())) + SSLSocket().async_handshake(type, cbFunc); + else + { + mCallback = cbFunc; + PlainSocket().async_receive(basio::buffer(mBuffer), basio::socket_base::message_peek, + boost::bind(&AutoSocket::handle_autodetect, this, basio::placeholders::error)); + + } + } + + template StreamType& getSocket() + { + if (isSecure()) + return SSLSocket(); + if (!isSecure()) + return PlainSocket(); + } + + template void async_shutdown(ShutdownHandler handler) + { + if (isSecure()) + SSLSocket().async_shutdown(handler); + else + { + PlainSocket().shutdown(plain_socket::shutdown_both); + if (handler) + mSocket.get_io_service().post(handler); + } + } + + template void async_read_some(const Seq& buffers, Handler handler) + { + if (isSecure()) + SSLSocket().async_read_some(buffers, handler); + else + PlainSocket().async_read_some(buffers, handler); + } + + template void async_write_some(const Seq& buffers, Handler handler) + { + if (isSecure()) + SSLSocket().async_write_some(buffers, handler); + else + PlainSocket().async_write_some(buffers, handler); + } + +protected: + void handle_autodetect(const error_code&); +}; + +#endif + +// vim:ts=4