diff --git a/src/cpp/ripple/Application.cpp b/src/cpp/ripple/Application.cpp index ec034eedf..58b75e8b4 100644 --- a/src/cpp/ripple/Application.cpp +++ b/src/cpp/ripple/Application.cpp @@ -20,6 +20,7 @@ SETUP_LOG(); LogPartition TaggedCachePartition("TaggedCache"); +LogPartition AutoSocketPartition("AutoSocket"); Application* theApp = NULL; DatabaseCon::DatabaseCon(const std::string& strName, const char *initStrings[], int initCount) diff --git a/src/cpp/ripple/AutoSocket.h b/src/cpp/ripple/AutoSocket.h index 106c101d2..ccf82ab50 100644 --- a/src/cpp/ripple/AutoSocket.h +++ b/src/cpp/ripple/AutoSocket.h @@ -44,10 +44,17 @@ public: mSocket = boost::make_shared(boost::ref(s), boost::ref(c)); } + AutoSocket(basio::io_service& s, bassl::context& c, bool secureOnly, bool plainOnly) + : mSecure(secureOnly), mBuffer((plainOnly || secureOnly) ? 0 : 4) + { + mSocket = boost::make_shared(boost::ref(s), boost::ref(c)); + } + bool isSecure() { return mSecure; } ssl_socket& SSLSocket() { return *mSocket; } plain_socket& PlainSocket() { return mSocket->next_layer(); } - void setSSLOnly() { mBuffer.clear(); } + void setSSLOnly() { mSecure = true;} + void setPlainOnly() { mBuffer.clear(); } lowest_layer_type& lowest_layer() { return mSocket->lowest_layer(); } @@ -60,14 +67,21 @@ public: void async_handshake(handshake_type type, callback cbFunc) { - if ((type == ssl_socket::client) || (mBuffer.empty())) - { + if ((type == ssl_socket::client) || (mSecure)) + { // must be ssl mSecure = true; mSocket->async_handshake(type, cbFunc); } + else if (mBuffer.empty()) + { // must be plain + mSecure = false; + cbFunc(error_code()); + } else + { // autodetect mSocket->next_layer().async_receive(basio::buffer(mBuffer), basio::socket_base::message_peek, boost::bind(&AutoSocket::handle_autodetect, this, cbFunc, basio::placeholders::error)); + } } template void async_shutdown(ShutdownHandler handler) @@ -110,6 +124,7 @@ protected: (mBuffer[2] < 127) && (mBuffer[2] > 31) && (mBuffer[3] < 127) && (mBuffer[3] > 31)) { // not ssl + mSecure = false; cbFunc(ec); } else