//------------------------------------------------------------------------------ /* Copyright (c) 2011-2013, OpenCoin, Inc. */ //============================================================================== #ifndef __AUTOSOCKET_H_ #define __AUTOSOCKET_H_ // Socket wrapper that supports both SSL and non-SSL connections. // Generally, handle it as you would an SSL connection. // To force a non-SSL connection, just don't call async_handshake. // To force SSL only inbound, call setSSLOnly. class AutoSocket { public: typedef boost::asio::ssl::stream ssl_socket; typedef boost::shared_ptr socket_ptr; typedef ssl_socket::next_layer_type plain_socket; typedef ssl_socket::lowest_layer_type lowest_layer_type; typedef ssl_socket::handshake_type handshake_type; typedef boost::system::error_code error_code; typedef boost::function callback; public: AutoSocket (boost::asio::io_service& s, boost::asio::ssl::context& c) : mSecure (false) , mBuffer (4) { mSocket = boost::make_shared (boost::ref (s), boost::ref (c)); } AutoSocket (boost::asio::io_service& s, boost::asio::ssl::context& c, bool secureOnly, bool plainOnly) : mSecure (secureOnly) , mBuffer ((plainOnly || secureOnly) ? 0 : 4) { mSocket = boost::make_shared (boost::ref (s), boost::ref (c)); } boost::asio::io_service& get_io_service () noexcept { return mSocket->get_io_service (); } bool isSecure () { return mSecure; } ssl_socket& SSLSocket () { return *mSocket; } plain_socket& PlainSocket () { return mSocket->next_layer (); } void setSSLOnly () { mSecure = true; } void setPlainOnly () { mBuffer.clear (); } lowest_layer_type& lowest_layer () { return mSocket->lowest_layer (); } void swap (AutoSocket& s) { mBuffer.swap (s.mBuffer); mSocket.swap (s.mSocket); std::swap (mSecure, s.mSecure); } boost::system::error_code cancel (boost::system::error_code& ec) { return lowest_layer ().cancel (ec); } static bool rfc2818_verify (const std::string& domain, bool preverified, boost::asio::ssl::verify_context& ctx) { using namespace ripple; if (boost::asio::ssl::rfc2818_verification (domain) (preverified, ctx)) return true; Log (lsWARNING, AutoSocketPartition) << "Outbound SSL connection to " << domain << " fails certificate verification"; return false; } boost::system::error_code verify (const std::string& strDomain) { boost::system::error_code ec; mSocket->set_verify_mode (boost::asio::ssl::verify_peer); // XXX Verify semantics of RFC 2818 are what we want. mSocket->set_verify_callback (boost::bind (&rfc2818_verify, strDomain, boost::placeholders::_1, boost::placeholders::_2), ec); return ec; } template BOOST_ASIO_INITFN_RESULT_TYPE(HandshakeHandler, void (boost::system::error_code)) async_handshake (handshake_type role, BOOST_ASIO_MOVE_ARG(HandshakeHandler) handler) { return async_handshake_cb (role, handler); } void async_handshake_cb (handshake_type type, callback cbFunc) { if ((type == ssl_socket::client) || (mSecure)) { // must be ssl mSecure = true; mSocket->async_handshake (type, cbFunc); } else if (mBuffer.empty ()) { // must be plain mSecure = false; mSocket->get_io_service ().post (boost::bind (cbFunc, error_code ())); } else { // autodetect mSocket->next_layer ().async_receive (boost::asio::buffer (mBuffer), boost::asio::socket_base::message_peek, boost::bind (&AutoSocket::handle_autodetect, this, cbFunc, boost::asio::placeholders::error, boost::asio::placeholders::bytes_transferred)); } } template void async_shutdown (ShutdownHandler handler) { if (isSecure ()) mSocket->async_shutdown (handler); else { error_code ec; try { lowest_layer ().shutdown (plain_socket::shutdown_both); } catch (boost::system::system_error& e) { ec = e.code(); } mSocket->get_io_service ().post (boost::bind (handler, ec)); } } template void async_read_some (const Seq& buffers, Handler handler) { if (isSecure ()) mSocket->async_read_some (buffers, handler); else PlainSocket ().async_read_some (buffers, handler); } template void async_read_until (const Seq& buffers, Condition condition, Handler handler) { if (isSecure ()) boost::asio::async_read_until (*mSocket, buffers, condition, handler); else boost::asio::async_read_until (PlainSocket (), buffers, condition, handler); } template void async_read_until (boost::asio::basic_streambuf& buffers, const std::string& delim, Handler handler) { if (isSecure ()) boost::asio::async_read_until (*mSocket, buffers, delim, handler); else boost::asio::async_read_until (PlainSocket (), buffers, delim, handler); } template void async_read_until (boost::asio::basic_streambuf& buffers, MatchCondition cond, Handler handler) { if (isSecure ()) boost::asio::async_read_until (*mSocket, buffers, cond, handler); else boost::asio::async_read_until (PlainSocket (), buffers, cond, handler); } template void async_write (const Buf& buffers, Handler handler) { if (isSecure ()) boost::asio::async_write (*mSocket, buffers, handler); else boost::asio::async_write (PlainSocket (), buffers, handler); } template void async_write (boost::asio::basic_streambuf& buffers, Handler handler) { if (isSecure ()) boost::asio::async_write (*mSocket, buffers, handler); else boost::asio::async_write (PlainSocket (), buffers, handler); } template void async_read (const Buf& buffers, Condition cond, Handler handler) { if (isSecure ()) boost::asio::async_read (*mSocket, buffers, cond, handler); else boost::asio::async_read (PlainSocket (), buffers, cond, handler); } template void async_read (boost::asio::basic_streambuf& buffers, Condition cond, Handler handler) { if (isSecure ()) boost::asio::async_read (*mSocket, buffers, cond, handler); else boost::asio::async_read (PlainSocket (), buffers, cond, handler); } template void async_read (const Buf& buffers, Handler handler) { if (isSecure ()) boost::asio::async_read (*mSocket, buffers, handler); else boost::asio::async_read (PlainSocket (), buffers, handler); } template void async_write_some (const Seq& buffers, Handler handler) { if (isSecure ()) mSocket->async_write_some (buffers, handler); else PlainSocket ().async_write_some (buffers, handler); } protected: void handle_autodetect (callback cbFunc, const error_code& ec, size_t bytesTransferred) { using namespace ripple; if (ec) { Log (lsWARNING, AutoSocketPartition) << "Handle autodetect error: " << ec; cbFunc (ec); } else if ((mBuffer[0] < 127) && (mBuffer[0] > 31) && ((bytesTransferred < 2) || ((mBuffer[1] < 127) && (mBuffer[1] > 31))) && ((bytesTransferred < 3) || ((mBuffer[2] < 127) && (mBuffer[2] > 31))) && ((bytesTransferred < 4) || ((mBuffer[3] < 127) && (mBuffer[3] > 31)))) { // not ssl if (AutoSocketPartition.doLog (lsTRACE)) Log (lsTRACE, AutoSocketPartition) << "non-SSL"; mSecure = false; cbFunc (ec); } else { // ssl if (AutoSocketPartition.doLog (lsTRACE)) Log (lsTRACE, AutoSocketPartition) << "SSL"; mSecure = true; mSocket->async_handshake (ssl_socket::server, cbFunc); } } private: static ripple::LogPartition AutoSocketPartition; socket_ptr mSocket; bool mSecure; std::vector mBuffer; }; #endif