From bc07943e79433389d83c8a2f74c1324d6d739636 Mon Sep 17 00:00:00 2001 From: JoelKatz Date: Thu, 24 Jan 2013 15:53:26 -0800 Subject: [PATCH] Moderate bastardization to support auto-TLS. --- src/cpp/websocketpp/src/connection.hpp | 81 +++++++++++---- src/cpp/websocketpp/src/endpoint.hpp | 4 +- src/cpp/websocketpp/src/roles/server.hpp | 121 +++++++++++++++++----- src/cpp/websocketpp/src/sockets/plain.hpp | 2 + src/cpp/websocketpp/src/sockets/tls.hpp | 2 + src/cpp/websocketpp/src/websocketpp.hpp | 4 + 6 files changed, 164 insertions(+), 50 deletions(-) diff --git a/src/cpp/websocketpp/src/connection.hpp b/src/cpp/websocketpp/src/connection.hpp index 53f37a559..9883148ac 100644 --- a/src/cpp/websocketpp/src/connection.hpp +++ b/src/cpp/websocketpp/src/connection.hpp @@ -898,20 +898,39 @@ public: !m_protocol_error) { // TODO: read timeout timer? - - boost::asio::async_read( - socket_type::get_socket(), - m_buf, - boost::asio::transfer_at_least(std::min( - m_read_threshold, - static_cast(m_processor->get_bytes_needed()) - )), - m_strand.wrap(boost::bind( - &type::handle_read_frame, - type::shared_from_this(), - boost::asio::placeholders::error - )) - ); + + if (socket_type::get_socket().isSecure()) + { + boost::asio::async_read( + socket_type::get_socket().SSLSocket(), + m_buf, + boost::asio::transfer_at_least(std::min( + m_read_threshold, + static_cast(m_processor->get_bytes_needed()) + )), + m_strand.wrap(boost::bind( + &type::handle_read_frame, + type::shared_from_this(), + boost::asio::placeholders::error + )) + ); + } + else + { + boost::asio::async_read( + socket_type::get_socket().PlainSocket(), + m_buf, + boost::asio::transfer_at_least(std::min( + m_read_threshold, + static_cast(m_processor->get_bytes_needed()) + )), + m_strand.wrap(boost::bind( + &type::handle_read_frame, + type::shared_from_this(), + boost::asio::placeholders::error + )) + ); + } } } public: @@ -1209,15 +1228,31 @@ public: //m_endpoint.alog().at(log::alevel::DEVEL) << "write header: " << zsutil::to_hex(m_write_queue.front()->get_header()) << log::endl; - boost::asio::async_write( - socket_type::get_socket(), - m_write_buf, - m_strand.wrap(boost::bind( - &type::handle_write, - type::shared_from_this(), - boost::asio::placeholders::error - )) - ); + if (socket_type::get_socket().isSecure()) + { + boost::asio::async_write( + socket_type::get_socket().SSLSocket(), + m_write_buf, + m_strand.wrap(boost::bind( + &type::handle_write, + type::shared_from_this(), + boost::asio::placeholders::error + )) + ); + } + else + { + boost::asio::async_write( + socket_type::get_socket().PlainSocket(), + m_write_buf, + m_strand.wrap(boost::bind( + &type::handle_write, + type::shared_from_this(), + boost::asio::placeholders::error + )) + ); + } + } else { // if we are in an inturrupted state and had nothing else to write // it is safe to terminate the connection. diff --git a/src/cpp/websocketpp/src/endpoint.hpp b/src/cpp/websocketpp/src/endpoint.hpp index f9ac5a582..862243c1c 100644 --- a/src/cpp/websocketpp/src/endpoint.hpp +++ b/src/cpp/websocketpp/src/endpoint.hpp @@ -29,7 +29,7 @@ #define WEBSOCKETPP_ENDPOINT_HPP #include "connection.hpp" -#include "sockets/plain.hpp" // should this be here? +#include "sockets/autotls.hpp" // should this be here? #include "logger/logger.hpp" #include @@ -74,7 +74,7 @@ protected: */ template < template class role, - template class socket = socket::plain, + template class socket = socket::autotls, template class logger = log::logger> class endpoint : public endpoint_base, diff --git a/src/cpp/websocketpp/src/roles/server.hpp b/src/cpp/websocketpp/src/roles/server.hpp index 270a478a0..7ee35ca32 100644 --- a/src/cpp/websocketpp/src/roles/server.hpp +++ b/src/cpp/websocketpp/src/roles/server.hpp @@ -53,6 +53,42 @@ namespace websocketpp { +typedef boost::asio::buffers_iterator bufIterator; + +static std::pair match_header(bufIterator begin, bufIterator end) +{ + static const std::string eol_match = "\n"; + static const std::string header_match = "\n\r\n"; + static const std::string alt_header_match = "\n\n"; + static const std::string flash_match = ""; + + // Do we have a complete HTTP request + bufIterator it = std::search(begin, end, header_match.begin(), header_match.end()); + if (it != end) + it += header_match.size() - 1; + else + { + it = std::search(begin, end, alt_header_match.begin(), alt_header_match.end()); + if (it != end) + it += alt_header_match.size() - 1; + } + if (it != end) // Yes, this is HTTP + return std::make_pair(it, true); + + // If we don't have a flash policy request, we're done + it = std::search(begin, end, flash_match.begin(), flash_match.end()); + if (it == end) // No match + return std::make_pair(end, false); + + // If we have a line ending before the flash policy request, treat as http + bufIterator it2 = std::search(begin, end, eol_match.begin(), eol_match.end()); + if ((it2 != end) || (it < it2)) + return std::make_pair(end, false); + + // Treat as flash policy request + return std::make_pair(it + flash_match.size() - 1, true); +} + // Forward declarations template struct endpoint_traits; @@ -513,18 +549,37 @@ void server::connection::async_init() { // TODO: make this value configurable m_connection.register_timeout(5000,fail::status::TIMEOUT_WS, "Timeout on WebSocket handshake"); - - boost::asio::async_read_until( - m_connection.get_socket(), - m_connection.buffer(), - "\r\n\r\n", - m_connection.get_strand().wrap(boost::bind( - &type::handle_read_request, - m_connection.shared_from_this(), - boost::asio::placeholders::error, - boost::asio::placeholders::bytes_transferred - )) - ); + + if (m_connection.get_socket().isSecure()) + { + boost::asio::async_read_until( + m_connection.get_socket().SSLSocket(), + m_connection.buffer(), +// match_header, + "\r\n\r\n", + m_connection.get_strand().wrap(boost::bind( + &type::handle_read_request, + m_connection.shared_from_this(), + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred + )) + ); + } + else + { + boost::asio::async_read_until( + m_connection.get_socket().PlainSocket(), + m_connection.buffer(), +// match_header, + "\r\n\r\n", + m_connection.get_strand().wrap(boost::bind( + &type::handle_read_request, + m_connection.shared_from_this(), + boost::asio::placeholders::error, + boost::asio::placeholders::bytes_transferred + )) + ); + } } /// processes the response from an async read for an HTTP header @@ -673,9 +728,9 @@ void server::connection::handle_read_request( { // TODO: this makes the assumption that WS and HTTP // default ports are the same. - m_uri.reset(new uri(m_endpoint.is_secure(),h,m_request.uri())); + m_uri.reset(new uri(m_connection.is_secure(),h,m_request.uri())); } else { - m_uri.reset(new uri(m_endpoint.is_secure(), + m_uri.reset(new uri(m_connection.is_secure(), h.substr(0,last_colon), h.substr(last_colon+1), m_request.uri())); @@ -795,17 +850,33 @@ void server::connection::write_response() { shared_const_buffer buffer(raw); m_endpoint.m_alog->at(log::alevel::DEBUG_HANDSHAKE) << raw << log::endl; - - boost::asio::async_write( - m_connection.get_socket(), - //boost::asio::buffer(raw), - buffer, - boost::bind( - &type::handle_write_response, - m_connection.shared_from_this(), - boost::asio::placeholders::error - ) - ); + + if (m_connection.get_socket().isSecure()) + { + boost::asio::async_write( + m_connection.get_socket().SSLSocket(), + //boost::asio::buffer(raw), + buffer, + boost::bind( + &type::handle_write_response, + m_connection.shared_from_this(), + boost::asio::placeholders::error + ) + ); + } + else + { + boost::asio::async_write( + m_connection.get_socket().PlainSocket(), + //boost::asio::buffer(raw), + buffer, + boost::bind( + &type::handle_write_response, + m_connection.shared_from_this(), + boost::asio::placeholders::error + ) + ); + } } template diff --git a/src/cpp/websocketpp/src/sockets/plain.hpp b/src/cpp/websocketpp/src/sockets/plain.hpp index 48aa91079..c7ecaa96b 100644 --- a/src/cpp/websocketpp/src/sockets/plain.hpp +++ b/src/cpp/websocketpp/src/sockets/plain.hpp @@ -25,6 +25,8 @@ * */ +#error Use Auto TLS only + #ifndef WEBSOCKETPP_SOCKET_PLAIN_HPP #define WEBSOCKETPP_SOCKET_PLAIN_HPP diff --git a/src/cpp/websocketpp/src/sockets/tls.hpp b/src/cpp/websocketpp/src/sockets/tls.hpp index 156901a4e..34e7af8ca 100644 --- a/src/cpp/websocketpp/src/sockets/tls.hpp +++ b/src/cpp/websocketpp/src/sockets/tls.hpp @@ -25,6 +25,8 @@ * */ +#error Use auto TLS only + #ifndef WEBSOCKETPP_SOCKET_TLS_HPP #define WEBSOCKETPP_SOCKET_TLS_HPP diff --git a/src/cpp/websocketpp/src/websocketpp.hpp b/src/cpp/websocketpp/src/websocketpp.hpp index 221d66e6f..c6692a9ba 100644 --- a/src/cpp/websocketpp/src/websocketpp.hpp +++ b/src/cpp/websocketpp/src/websocketpp.hpp @@ -41,6 +41,10 @@ namespace websocketpp { typedef websocketpp::endpoint server_tls; #endif + #ifdef WEBSOCKETPP_SOCKET_AUTOTLS_HPP + typedef websocketpp::endpoint server_autotls; + #endif #endif