diff --git a/src/cpp/ripple/AutoSocket.cpp b/src/cpp/ripple/AutoSocket.cpp new file mode 100644 index 0000000000..f421937fa4 --- /dev/null +++ b/src/cpp/ripple/AutoSocket.cpp @@ -0,0 +1,28 @@ + +#include "AutoSocket.h" + +#include + +void AutoSocket::handle_autodetect(const error_code& ec) +{ + if (ec) + { + if (mCallback) + mCallback(ec); + return; + } + + if ((mBuffer[0] < 127) && (mBuffer[0] > 31) && + (mBuffer[1] < 127) && (mBuffer[1] > 31) && + (mBuffer[2] < 127) && (mBuffer[2] > 31) && + (mBuffer[3] < 127) && (mBuffer[3] > 31)) + { // non-SSL + if (mCallback) + mCallback(ec); + } + else + { // ssl + mSecure = true; + SSLSocket().async_handshake(ssl_socket::server, mCallback); + } +} diff --git a/src/cpp/ripple/AutoSocket.h b/src/cpp/ripple/AutoSocket.h new file mode 100644 index 0000000000..019618dbd8 --- /dev/null +++ b/src/cpp/ripple/AutoSocket.h @@ -0,0 +1,94 @@ +#ifndef __AUTOSOCKET_H_ +#define __AUTOSOCKET_H_ + +#include + +#include +#include +#include +#include + +// Socket wrapper that supports both SSL and non-SSL connections. +// Generally, handle it as you would an SSL connection. +// For outbound non-SSL connections, just don't call async_handshake. + +namespace basio = boost::asio; +namespace bassl = basio::ssl; + +class AutoSocket +{ +public: + typedef bassl::stream ssl_socket; + typedef ssl_socket::next_layer_type plain_socket; + typedef boost::system::error_code error_code; + typedef boost::function callback; + +protected: + ssl_socket mSocket; + bool mSecure; + callback mCallback; + + std::vector mBuffer; + +public: + AutoSocket(basio::io_service& s, bassl::context& c) : mSocket(s, c), mSecure(false), mBuffer(4) { ; } + + bool isSecure() { return mSecure; } + ssl_socket& SSLSocket() { return mSocket; } + plain_socket& PlainSocket() { return mSocket.next_layer(); } + + void async_handshake(ssl_socket::handshake_type type, callback cbFunc) + { + mSecure = true; + if (type == ssl_socket::client) + SSLSocket().async_handshake(type, cbFunc); + else + { + mCallback = cbFunc; + PlainSocket().async_receive(basio::buffer(mBuffer), basio::socket_base::message_peek, + boost::bind(&AutoSocket::handle_autodetect, this, basio::placeholders::error)); + + } + } + + template StreamType& getSocket() + { + if (isSecure()) + return SSLSocket(); + if (!isSecure()) + return PlainSocket(); + } + + template void async_shutdown(ShutdownHandler handler) + { + if (isSecure()) + SSLSocket().async_shutdown(handler); + else + { + PlainSocket().shutdown(plain_socket::shutdown_both); + if (handler) + mSocket.get_io_service().post(handler); + } + } + + template void async_read_some(const Seq& buffers, Handler handler) + { + if (isSecure()) + SSLSocket().async_read_some(buffers, handler); + else + PlainSocket().async_read_some(buffers, handler); + } + + template void async_write_some(const Seq& buffers, Handler handler) + { + if (isSecure()) + SSLSocket().async_write_some(buffers, handler); + else + PlainSocket().async_write_some(buffers, handler); + } + +protected: + void handle_autodetect(const error_code&); +}; + +#endif