diff --git a/examples/chat_server/chat_server.cpp b/examples/chat_server/chat_server.cpp index fbaca20b44..6c19452b8b 100644 --- a/examples/chat_server/chat_server.cpp +++ b/examples/chat_server/chat_server.cpp @@ -39,7 +39,9 @@ int main(int argc, char* argv[]) { short port = 5000; if (argc == 3) { + // TODO: input validation? port = atoi(argv[2]); + std::stringstream temp; temp << argv[1] << ":" << port; @@ -57,6 +59,17 @@ int main(int argc, char* argv[]) { new websocketpp::server(io_service,endpoint,host,chat_handler) ); + // setup server settings + server->add_host(host); + + // Chat server should only be receiving small text messages, reduce max + // message size limit slightly to save memory, improve performance, and + // guard against DoS attacks. + server->set_max_message_size(0xFFFF); // 64KiB + + // start the server + server->start_accept(); + std::cout << "Starting chat server on " << host << std::endl; io_service.run(); diff --git a/examples/echo_server/echo_server b/examples/echo_server/echo_server index cdc048373f..42c197d31a 100755 Binary files a/examples/echo_server/echo_server and b/examples/echo_server/echo_server differ diff --git a/examples/echo_server/echo_server.cpp b/examples/echo_server/echo_server.cpp index 030d719907..c3ee009025 100644 --- a/examples/echo_server/echo_server.cpp +++ b/examples/echo_server/echo_server.cpp @@ -59,10 +59,19 @@ int main(int argc, char* argv[]) { new websocketpp::server(io_service,endpoint,echo_handler) ); + // setup server settings server->add_host(host); + // bump up max message size to maximum since we may be using the echo + // server to test performance and protocol extremes. + server->set_max_message_size(websocketpp::frame::PAYLOAD_64BIT_LIMIT); + + // start the server + server->start_accept(); + std::cout << "Starting echo server on " << host << std::endl; + // start asio io_service.run(); } catch (std::exception& e) { std::cerr << "Exception: " << e.what() << std::endl; diff --git a/src/websocket_server.cpp b/src/websocket_server.cpp index 997c4618db..ada8b01453 100644 --- a/src/websocket_server.cpp +++ b/src/websocket_server.cpp @@ -36,11 +36,10 @@ using websocketpp::server; server::server(boost::asio::io_service& io_service, const tcp::endpoint& endpoint, connection_handler_ptr defc) - : m_io_service(io_service), + : m_max_message_size(DEFAULT_MAX_MESSAGE_SIZE), + m_io_service(io_service), m_acceptor(io_service, endpoint), - m_def_con_handler(defc) { - this->start_accept(); -} + m_def_con_handler(defc) {} void server::add_host(std::string host) { m_hosts.insert(host); @@ -50,8 +49,47 @@ void server::remove_host(std::string host) { m_hosts.erase(host); } + +void server::set_max_message_size(uint64_t val) { + if (val > frame::PAYLOAD_64BIT_LIMIT) { + std::stringstream err; + err << "Invalid maximum message size: " << val; + + // TODO: Figure out what the ideal error behavior for this method. + // Options: + // Throw exception + // Log error and set value to maximum allowed + // Log error and leave value at whatever it was before + throw server_error(err.str()); + } + m_max_message_size = val; +} + +bool server::validate_host(std::string host) { + if (m_hosts.find(host) == m_hosts.end()) { + return false; + } + return true; +} + +bool server::validate_message_size(uint64_t val) { + if (val > m_max_message_size) { + return false; + } + return true; +} + +void server::error_log(std::string msg) { + std::cerr << "[Error Log] " << msg << std::endl; +} +void server::access_log(std::string msg) { + std::cout << "[Access Log] " << msg << std::endl; +} + void server::start_accept() { - session_ptr new_ws(new session(m_io_service,m_def_con_handler)); + session_ptr new_ws(new session(shared_from_this(), + m_io_service, + m_def_con_handler)); m_acceptor.async_accept( new_ws->socket(), @@ -68,15 +106,13 @@ void server::handle_accept(session_ptr session, const boost::system::error_code& error) { if (!error) { - // set up session - std::set::iterator it; - for (it = m_hosts.begin(); it != m_hosts.end(); it++) { - session->add_host(*it); - } - session->start(); } else { - std::cout << "Error" << std::endl; + std::stringstream err; + err << "Error accepting socket connection: " << error; + + error_log(err.str()); + throw server_error(err.str()); } this->start_accept(); diff --git a/src/websocket_server.hpp b/src/websocket_server.hpp index 4b71fcdbff..4821511825 100644 --- a/src/websocket_server.hpp +++ b/src/websocket_server.hpp @@ -28,47 +28,80 @@ #ifndef WEBSOCKET_SERVER_HPP #define WEBSOCKET_SERVER_HPP -#include "websocket_session.hpp" - #include #include #include +namespace websocketpp { + class server; + typedef boost::shared_ptr server_ptr; +} + +#include "websocketpp.hpp" +#include "websocket_session.hpp" +#include "websocket_connection_handler.hpp" + using boost::asio::ip::tcp; namespace websocketpp { -class server { +class server_error : public std::exception { +public: + server_error(const std::string& msg) + : m_msg(msg) {} + ~server_error() throw() {} + + virtual const char* what() const throw() { + return m_msg.c_str(); + } +private: + std::string m_msg; +}; + +class server : public boost::enable_shared_from_this { public: server(boost::asio::io_service& io_service, const tcp::endpoint& endpoint, connection_handler_ptr defc); + // creates a new session object and connects the next websocket + // connection to it. + void start_accept(); + // Add or remove a host string (host:port) to the list of acceptable // hosts to accept websocket connections from. Additions/deletions here // only affect new connections. void add_host(std::string host); void remove_host(std::string host); + + void set_max_message_size(uint64_t val); + + // Check if this server will respond to this host. + bool validate_host(std::string host); + + // Check if message size is within server's acceptable parameters + bool validate_message_size(uint64_t val); + + // write to the server's logs + void error_log(std::string msg); + void access_log(std::string msg); private: - // creates a new session object and connects the next websocket - // connection to it. - void start_accept(); + // if no errors starts the session's read loop and returns to the // start_accept phase. void handle_accept(session_ptr session, const boost::system::error_code& error); - + private: std::set m_hosts; + uint64_t m_max_message_size; boost::asio::io_service& m_io_service; tcp::acceptor m_acceptor; connection_handler_ptr m_def_con_handler; }; -typedef boost::shared_ptr server_ptr; - } #endif // WEBSOCKET_SERVER_HPP \ No newline at end of file diff --git a/src/websocket_session.cpp b/src/websocket_session.cpp index ca62cb4d51..9f4f47c5b9 100644 --- a/src/websocket_session.cpp +++ b/src/websocket_session.cpp @@ -40,10 +40,11 @@ using websocketpp::session; -session::session (boost::asio::io_service& io_service, +session::session (server_ptr s,boost::asio::io_service& io_service, connection_handler_ptr defc) - : m_socket(io_service), - m_status(CONNECTING), + : m_status(CONNECTING), + m_server(s), + m_socket(io_service), m_local_interface(defc) {} tcp::socket& session::socket() { @@ -51,7 +52,9 @@ tcp::socket& session::socket() { } void session::start() { - std::cout << "[Connection " << this << "] WebSocket Connection request from " << m_socket.remote_endpoint() << std::endl; + std::stringstream msg; + msg << "[Connection " << this << "] WebSocket Connection request from " << m_socket.remote_endpoint(); + m_server->access_log(msg.str()); // async read to handle_read_handshake boost::asio::async_read_until( @@ -75,21 +78,6 @@ void session::set_handler(connection_handler_ptr new_con) { m_local_interface->connect(shared_from_this()); } -void session::add_host(std::string host) { - m_hosts.insert(host); -} - -void session::remove_host(std::string host) { - m_hosts.erase(host); -} - -void session::set_max_message_size(uint64_t val) { - if (val > frame::PAYLOAD_64BIT_LIMIT) { - throw "illegal maximum message size"; - } - m_max_message_size = val; -} - std::string session::get_header(const std::string& key) const { std::map::const_iterator h = m_headers.find(key); @@ -280,7 +268,7 @@ void session::handle_read_handshake(const boost::system::error_code& e, h = get_header("Host"); if (h == "") { throw(handshake_error("Required Host header is missing",400)); - } else if (m_hosts.find(h) == m_hosts.end()) { + } else if (!m_server->validate_host(h)) { err << "Host " << h << " is not one of this server's names."; throw(handshake_error(err.str(),400)); } @@ -325,7 +313,10 @@ void session::handle_read_handshake(const boost::system::error_code& e, } } catch (const handshake_error& e) { - std::cerr << "Caught handshake exception: " << e.what() << std::endl; + std::stringstream err; + err << "Caught handshake exception: " << e.what(); + + m_server->error_log(err.str()); e.write(shared_from_this()); return; } @@ -345,7 +336,7 @@ void session::write_handshake() { sha << server_key.c_str(); if (!sha.Result(message_digest)) { - std::cerr << "Error computing sha1 hash, killing connection." << std::endl; + m_server->error_log("Error computing handshake sha1 hash."); write_http_error(500,""); return; } @@ -384,7 +375,11 @@ void session::handle_write_handshake(const boost::system::error_code& error) { return; } - std::cout << "WebSocket Version " << m_version << " connection opened." << std::endl; + std::stringstream msg; + msg << "[Connection " << this << "] WebSocket Version " + << m_version << " connection opened."; + + m_server->access_log(msg.str()); m_status = OPEN; @@ -417,6 +412,12 @@ void session::write_http_error(int code,const std::string &msg) { boost::asio::placeholders::error ) ); + + std::stringstream err; + err << "Handshake ended with HTTP error: " << code << " " + << (msg != "" ? msg : lookup_http_error_string(code)); + + m_server->error_log(err.str()); } void session::handle_write_http_error(const boost::system::error_code& error) { diff --git a/src/websocket_session.hpp b/src/websocket_session.hpp index 76339a4534..c07ceff675 100644 --- a/src/websocket_session.hpp +++ b/src/websocket_session.hpp @@ -52,6 +52,7 @@ namespace websocketpp { class handshake_error; } +#include "websocket_server.hpp" #include "websocket_frame.hpp" #include "websocket_connection_handler.hpp" @@ -75,7 +76,9 @@ public: friend class handshake_error; - session (boost::asio::io_service& io_service, connection_handler_ptr defc); + session (server_ptr s, + boost::asio::io_service& io_service, + connection_handler_ptr defc); tcp::socket& socket(); @@ -97,17 +100,6 @@ public: /*** HANDSHAKE INTERFACE ***/ - // Add or remove a host string (host:port) to the list of acceptable - // hosts to accept websocket connections from. Additions/deletions here - // only affect only this connection. This list is initially populated based - // on values stored in the server. - void add_host(std::string host); - void remove_host(std::string host); - - // Sets the maximum allowed message size. Higher values will require larger - // memory buffers. Maximum value is 2^63, default value is ?? - void set_max_message_size(uint64_t); - // gets the value of a header or the empty string if not present. std::string get_header(const std::string &key) const; // adds an arbitrary header to the server handshake HTTP response. @@ -202,11 +194,6 @@ private: std::string lookup_http_error_string(int code); private: - // Connection specific constants (defaults to values from server, can be - // changed on a per connection basis if needed.) - std::set m_hosts; - uint64_t m_max_message_size; - // Immutable state about the current connection from the handshake std::string m_request; std::map m_headers; @@ -217,6 +204,7 @@ private: status_code m_status; // Connection Resources + server_ptr m_server; tcp::socket m_socket; connection_handler_ptr m_local_interface; diff --git a/src/websocketpp.hpp b/src/websocketpp.hpp index 7d28071906..82d4bc8ba6 100644 --- a/src/websocketpp.hpp +++ b/src/websocketpp.hpp @@ -30,6 +30,11 @@ #ifndef WEBSOCKETPP_HPP #define WEBSOCKETPP_HPP +// Defaults +namespace websocketpp { + const uint64_t DEFAULT_MAX_MESSAGE_SIZE = 0xFFFFFF; // ~16MB +} + #include "websocket_server.hpp" #endif // WEBSOCKETPP_HPP